commit 159cf9fcfe75f94a20aaf8eee1daf08401c10312 Author: DigiJ Date: Fri Mar 13 12:56:43 2026 -0700 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8cce592 --- /dev/null +++ b/.gitignore @@ -0,0 +1,217 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.venv1/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be added to the global gitignore or merged into this project gitignore. For a PyCharm +# project, it is generally recommended to not ignore these files. +.idea/ + +# VS Code +.vscode/ + +# Model files and checkpoints +*.bin +*.safetensors +*.gguf +*.ggml +*.pth +*.pt +*.ckpt +*.pkl +*.pickle + +# Large data files +*.csv +*.json +*.txt +*.log + +# API Keys and secrets +*.key +*.env +config.yaml +secrets.yaml +api_keys.json + +# Temporary files +temp/ +tmp/ +*.tmp +*.temp + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# DarkHal specific ignores +logs/ +output/ +models/ +checkpoints/ +cache/ + +# Project-specific directories to ignore at repo root +/experimental_agent/ +/llama-cpp-python/ +/llm_chess-main/ +/Darkhal/ +/darkhall.egg-info/ +/.claude/ \ No newline at end of file diff --git a/0.19.0 b/0.19.0 new file mode 100644 index 0000000..e786e9c --- /dev/null +++ b/0.19.0 @@ -0,0 +1 @@ + - python-dotenv diff --git a/0.2.80 b/0.2.80 new file mode 100644 index 0000000..4a48114 --- /dev/null +++ b/0.2.80 @@ -0,0 +1 @@ + - llama-cpp-python diff --git a/0.24.0 b/0.24.0 new file mode 100644 index 0000000..b7e2c7e --- /dev/null +++ b/0.24.0 @@ -0,0 +1 @@ + - huggingface_hub diff --git a/1.21.0 b/1.21.0 new file mode 100644 index 0000000..10ca255 --- /dev/null +++ b/1.21.0 @@ -0,0 +1 @@ + - numpy diff --git a/2.32.0 b/2.32.0 new file mode 100644 index 0000000..b9674d9 --- /dev/null +++ b/2.32.0 @@ -0,0 +1 @@ + - requests diff --git a/5.8.0 b/5.8.0 new file mode 100644 index 0000000..b790799 --- /dev/null +++ b/5.8.0 @@ -0,0 +1 @@ + - psutil diff --git a/9.0.0 b/9.0.0 new file mode 100644 index 0000000..4170d2b --- /dev/null +++ b/9.0.0 @@ -0,0 +1 @@ + - Pillow diff --git a/=0.19.0 b/=0.19.0 new file mode 100644 index 0000000..e69de29 diff --git a/=0.2.3 b/=0.2.3 new file mode 100644 index 0000000..5699e88 --- /dev/null +++ b/=0.2.3 @@ -0,0 +1,48 @@ +Collecting auto-gptq + Downloading auto_gptq-0.7.1.tar.gz (126 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/90/e5/b22697903982284fe284568fb2663a2196694a8eee637f5cf4ccfe435a38/auto_gptq-0.7.1.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/90/e5/b22697903982284fe284568fb2663a2196694a8eee637f5cf4ccfe435a38/auto_gptq-0.7.1.tar.gz has inconsistent version: expected '0.7.1', but metadata has '0.7.1+cu118' + Downloading auto_gptq-0.7.0.tar.gz (124 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/34/71/c3e73cf17681f6ff4754ef8f4cb8b67af3def230fc8711eac1250bbd78d5/auto_gptq-0.7.0.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/34/71/c3e73cf17681f6ff4754ef8f4cb8b67af3def230fc8711eac1250bbd78d5/auto_gptq-0.7.0.tar.gz has inconsistent version: expected '0.7.0', but metadata has '0.7.0+cu118' + Downloading auto_gptq-0.6.0.tar.gz (120 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/49/af/02b66e55dfd9aeb0ece923843043724ed7432ec0c649ea0f3b9fa1dd90c6/auto_gptq-0.6.0.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/49/af/02b66e55dfd9aeb0ece923843043724ed7432ec0c649ea0f3b9fa1dd90c6/auto_gptq-0.6.0.tar.gz has inconsistent version: expected '0.6.0', but metadata has '0.6.0+cu118' + Downloading auto_gptq-0.5.1.tar.gz (112 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/db/77/ec5a16c5625b0791dccfe5e42356171332ed3537c1df505d64a162148c8f/auto_gptq-0.5.1.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/db/77/ec5a16c5625b0791dccfe5e42356171332ed3537c1df505d64a162148c8f/auto_gptq-0.5.1.tar.gz has inconsistent version: expected '0.5.1', but metadata has '0.5.1+cu118' + Downloading auto_gptq-0.5.0.tar.gz (111 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/3d/fa/c2cd09965b2dbf4e454d9f073376922f7139a574f617f70a22adb203eced/auto_gptq-0.5.0.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/3d/fa/c2cd09965b2dbf4e454d9f073376922f7139a574f617f70a22adb203eced/auto_gptq-0.5.0.tar.gz has inconsistent version: expected '0.5.0', but metadata has '0.5.0+cu118' + Downloading auto_gptq-0.3.2.tar.gz (63 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Discarding https://files.pythonhosted.org/packages/1b/79/5a3a7d877a9b0a72f528e9977ec65cdb9fad800fa4f5110f87f2acaaf6fe/auto_gptq-0.3.2.tar.gz (from https://pypi.org/simple/auto-gptq/) (requires-python:>=3.8.0): Requested auto-gptq from https://files.pythonhosted.org/packages/1b/79/5a3a7d877a9b0a72f528e9977ec65cdb9fad800fa4f5110f87f2acaaf6fe/auto_gptq-0.3.2.tar.gz has inconsistent version: expected '0.3.2', but metadata has '0.3.2+cu118' + Downloading auto_gptq-0.3.1.tar.gz (63 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Collecting autoawq + Downloading autoawq-0.2.9.tar.gz (74 kB) + Preparing metadata (setup.py): started + Preparing metadata (setup.py): finished with status 'done' +Collecting exllamav2 + Downloading exllamav2-0.3.2-py3-none-any.whl.metadata (430 bytes) +Requirement already satisfied: accelerate>=0.19.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from auto-gptq) (1.10.0) +Collecting datasets (from auto-gptq) + Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB) +Requirement already satisfied: numpy in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from auto-gptq) (2.3.2) +Collecting rouge (from auto-gptq) + Downloading rouge-1.0.1-py3-none-any.whl.metadata (4.1 kB) +Requirement already satisfied: torch>=1.13.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from auto-gptq) (2.7.1+cu118) +Requirement already satisfied: safetensors in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from auto-gptq) (0.6.2) +Requirement already satisfied: transformers>=4.31.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from auto-gptq) (4.55.2) +Collecting peft (from auto-gptq) + Downloading peft-0.17.0-py3-none-any.whl.metadata (14 kB) +INFO: pip is looking at multiple versions of autoawq to determine which version is compatible with other requirements. This could take a while. +Collecting autoawq + Downloading autoawq-0.2.8.tar.gz (71 kB) diff --git a/=0.2.5 b/=0.2.5 new file mode 100644 index 0000000..e69de29 diff --git a/=0.33.0 b/=0.33.0 new file mode 100644 index 0000000..e69de29 diff --git a/=0.4.2 b/=0.4.2 new file mode 100644 index 0000000..fe620fe --- /dev/null +++ b/=0.4.2 @@ -0,0 +1,35 @@ +Collecting transformers + Using cached transformers-4.55.2-py3-none-any.whl.metadata (41 kB) +Collecting accelerate + Using cached accelerate-1.10.0-py3-none-any.whl.metadata (19 kB) +Requirement already satisfied: safetensors in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (0.6.2) +Requirement already satisfied: filelock in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (3.19.1) +Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (0.34.4) +Requirement already satisfied: numpy>=1.17 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (2.3.2) +Requirement already satisfied: packaging>=20.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (25.0) +Requirement already satisfied: pyyaml>=5.1 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (6.0.2) +Requirement already satisfied: regex!=2019.12.17 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (2025.7.34) +Requirement already satisfied: requests in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (2.32.5) +Requirement already satisfied: tokenizers<0.22,>=0.21 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (0.21.4) +Requirement already satisfied: tqdm>=4.27 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from transformers) (4.67.1) +Requirement already satisfied: fsspec>=2023.5.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (2025.7.0) +Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (4.14.1) +Requirement already satisfied: psutil in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from accelerate) (7.0.0) +Requirement already satisfied: torch>=2.0.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from accelerate) (2.7.1+cu118) +Requirement already satisfied: sympy>=1.13.3 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from torch>=2.0.0->accelerate) (1.13.3) +Requirement already satisfied: networkx in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from torch>=2.0.0->accelerate) (3.5) +Requirement already satisfied: jinja2 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from torch>=2.0.0->accelerate) (3.1.6) +Requirement already satisfied: setuptools in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from torch>=2.0.0->accelerate) (80.9.0) +Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0) +Requirement already satisfied: colorama in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from tqdm>=4.27->transformers) (0.4.6) +Requirement already satisfied: MarkupSafe>=2.0 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2) +Requirement already satisfied: charset_normalizer<4,>=2 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from requests->transformers) (3.4.3) +Requirement already satisfied: idna<4,>=2.5 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from requests->transformers) (3.10) +Requirement already satisfied: urllib3<3,>=1.21.1 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from requests->transformers) (2.5.0) +Requirement already satisfied: certifi>=2017.4.17 in c:\users\mdavi\pycharmprojects\llm_train\.venv\lib\site-packages (from requests->transformers) (2025.8.3) +Downloading transformers-4.55.2-py3-none-any.whl (11.3 MB) + ---------------------------------------- 11.3/11.3 MB 39.1 MB/s 0:00:00 +Downloading accelerate-1.10.0-py3-none-any.whl (374 kB) +Installing collected packages: accelerate, transformers + +Successfully installed accelerate-1.10.0 transformers-4.55.2 diff --git a/=0.7.1 b/=0.7.1 new file mode 100644 index 0000000..e69de29 diff --git a/=2.0.0 b/=2.0.0 new file mode 100644 index 0000000..dace960 --- /dev/null +++ b/=2.0.0 @@ -0,0 +1 @@ +Looking in indexes: https://download.pytorch.org/whl/cu121 diff --git a/=4.43.0 b/=4.43.0 new file mode 100644 index 0000000..e69de29 diff --git a/DEPENDENCY_INSTALLER_README.md b/DEPENDENCY_INSTALLER_README.md new file mode 100644 index 0000000..7530246 --- /dev/null +++ b/DEPENDENCY_INSTALLER_README.md @@ -0,0 +1,161 @@ +# Windows Dependency Installer for LLM_Train + +This dependency installer helps you automatically install common software packages required for LLM_Train on Windows using the Chocolatey package manager. + +## Quick Start + +### Option 1: Batch File (Recommended) +1. Double-click `install_dependencies.bat` +2. Grant administrator permissions when prompted +3. Select packages to install in the GUI + +### Option 2: PowerShell Script +1. Right-click `install_dependencies.ps1` → "Run with PowerShell" +2. Grant administrator permissions when prompted +3. Select packages to install in the GUI + +### Option 3: Direct Python Execution +```bash +# Run as administrator +python windows_dependency_installer.py +``` + +## What Gets Installed + +### Essential Packages (Auto-selected) +- **Git** - Version control system (required for repository cloning) +- **Python 3** - Python programming language (if not already installed) +- **7-Zip** - File archiver for extracting downloads +- **Visual C++ Redistributables** - Microsoft runtime libraries + +### Development Tools +- **Visual Studio Code** - Advanced code editor with Python support +- **Notepad++** - Enhanced text editor + +### GPU Acceleration +- **CUDA Toolkit** - NVIDIA CUDA development toolkit for GPU acceleration +- **NVIDIA Display Driver** - Latest NVIDIA graphics drivers + +### System Utilities +- **Wget** - Command-line downloader +- **cURL** - Data transfer tool +- **PowerToys** - Windows system utilities + +### Runtimes +- **.NET Runtime** - Microsoft .NET framework + +### Optional Tools +- **WinRAR** - Alternative file archiver +- **Firefox** - Web browser +- **VLC Media Player** - Media player + +## System Requirements + +- **Windows 10/11** (Windows 8.1 may work but is not tested) +- **Administrator privileges** (required for Chocolatey and package installation) +- **Internet connection** (for downloading packages) +- **Python 3.7+** (for running the installer GUI) + +## Features + +### Chocolatey Integration +- Automatically installs Chocolatey if not present +- Uses Chocolatey's robust package management +- Handles dependencies automatically + +### Smart Package Selection +- **Select Essential** - Chooses only required packages +- **Select All** - Selects all available packages +- **Custom Selection** - Pick individual packages + +### Installation Monitoring +- Real-time installation log +- Progress tracking +- Success/failure reporting +- Package status checking + +### System Status Checks +- Administrator privilege detection +- Chocolatey installation status +- Individual package installation status + +## Troubleshooting + +### "Python not found" Error +1. Install Python from https://python.org/downloads/ +2. During installation, check "Add Python to PATH" +3. Restart your command prompt/PowerShell + +### "Administrator privileges required" Error +1. Right-click the batch file → "Run as administrator" +2. Or open Command Prompt as administrator and run manually + +### "Chocolatey installation failed" Error +1. Ensure you're running as administrator +2. Check your internet connection +3. Temporarily disable antivirus software during installation +4. Check Windows execution policy: `Set-ExecutionPolicy RemoteSigned` + +### Package Installation Failures +1. Check the installation log for specific error messages +2. Try installing packages individually +3. Ensure sufficient disk space +4. Check for conflicting software + +### Network/Firewall Issues +1. Ensure Chocolatey URLs are not blocked: + - https://community.chocolatey.org/ + - https://packages.chocolatey.org/ +2. Configure proxy settings if behind corporate firewall +3. Temporarily disable firewall/antivirus + +## Manual Installation + +If the automatic installer fails, you can install Chocolatey manually: + +1. Open PowerShell as Administrator +2. Run: + ```powershell + Set-ExecutionPolicy Bypass -Scope Process -Force; + [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; + iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + ``` +3. Then install packages manually: + ```powershell + choco install git python 7zip vcredist-all -y + choco install cuda nvidia-display-driver -y # For GPU support + ``` + +## Package Descriptions + +### Why Each Package? + +- **Git**: Required for cloning model repositories and version control +- **Python**: Core runtime for LLM_Train (if system Python is outdated) +- **7-Zip**: Many model files are compressed and need extraction +- **Visual C++ Redistributables**: Required by many Python packages and binaries +- **CUDA Toolkit**: Enables GPU acceleration for faster model inference +- **NVIDIA Drivers**: Latest drivers for optimal GPU performance +- **Visual Studio Code**: Best IDE for Python development and debugging +- **Wget/cURL**: Alternative download tools for model files +- **PowerToys**: Useful Windows utilities for power users + +## Security Notes + +- All packages are installed from official Chocolatey community repository +- Chocolatey packages are maintained by the community and Microsoft +- Administrator privileges are required only for system-wide installation +- No personal data is collected or transmitted + +## Support + +If you encounter issues: + +1. Check the installation log for error messages +2. Search for the specific error on Chocolatey community forums +3. Try installing individual packages manually +4. Ensure your Windows is up to date + +## License + +This installer uses Chocolatey (Apache 2.0 License) and installs various packages with their respective licenses. Please review individual package licenses as needed. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..3624382 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,50 @@ +# Include the README and license files +include README.md +include *.txt +include *.md + +# Include configuration files +include pyproject.toml +include requirements.txt + +# Include everything in darkhal package directory +recursive-include darkhal * + +# Include any JSON configuration files at root +include *.json + +# Include environment files +include *.env + +# Include any shell scripts that might be useful +include *.sh +include *.bat + +# Exclude development and build files +exclude .gitignore +exclude *.pyc +exclude __pycache__ +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] +recursive-exclude * *.orig +recursive-exclude * *.rej +exclude .coverage +exclude .tox +recursive-exclude build * +recursive-exclude dist * +recursive-exclude .git * +recursive-exclude win-install * +recursive-exclude debian-install * +recursive-exclude tests * +recursive-exclude llm_chess-main * +recursive-exclude temp_llm_chess * +recursive-exclude llama-cpp-python * +recursive-exclude downloads * +recursive-exclude models * +recursive-exclude games * + +# Exclude temporary and cache files +exclude .DS_Store +recursive-exclude * .DS_Store +exclude Thumbs.db +recursive-exclude * Thumbs.db \ No newline at end of file diff --git a/MANUAL.md b/MANUAL.md new file mode 100644 index 0000000..97a01d3 --- /dev/null +++ b/MANUAL.md @@ -0,0 +1,704 @@ +# DarkHal 2.0 User Manual + +## Table of Contents + +1. [Introduction](#introduction) +2. [Getting Started](#getting-started) +3. [Main Interface](#main-interface) +4. [Model Management](#model-management) +5. [Agent Mode](#agent-mode) +6. [Advanced Features](#advanced-features) +7. [Troubleshooting](#troubleshooting) +8. [Keyboard Shortcuts](#keyboard-shortcuts) +9. [Command Reference](#command-reference) +10. [FAQ](#faq) + +--- + +## Introduction + +DarkHal 2.0 is an advanced AI model management platform that provides comprehensive tools for loading, running, and interacting with Large Language Models (LLMs). This manual covers all features and capabilities of the platform. + +### System Requirements + +**Minimum Requirements:** +- Windows 10/11, Linux (Ubuntu 20.04+), or macOS 11+ +- Python 3.8 or higher +- 8GB RAM +- 20GB free disk space +- Internet connection for downloading models + +**Recommended Requirements:** +- 16GB+ RAM +- NVIDIA GPU with 8GB+ VRAM +- 100GB+ free disk space for models +- High-speed internet for model downloads + +--- + +## Getting Started + +### First Launch + +1. **Start DarkHal**: + ```bash + python main.py --gui + ``` + +2. **Initial Setup**: + - The splash screen will appear showing system information + - Configure your models directory via Settings menu + - Set up HuggingFace API token if you plan to download models + +3. **Load Your First Model**: + - Click "Browse Model" to select a model file + - Or use "Browse Folder" to select a model directory + - Click "Load Model" to initialize + +### Understanding the Interface + +The main window consists of several tabs: +- **Run**: Chat interface and model loading +- **Model Library**: Browse and manage local models +- **HuggingFace**: Download models from HuggingFace Hub +- **Downloads**: Monitor active downloads +- **MCP**: Model Context Protocol server +- **Converter**: Convert between model formats +- **Chess**: Specialized chess engine interface + +--- + +## Main Interface + +### Run Tab + +The Run tab is your primary workspace for interacting with models. + +#### Model Selection Panel + +**Model Path Input**: +- Enter the full path to your model file or directory +- Supports drag-and-drop from file explorer +- Auto-completes recently used models + +**Browse Model Button**: +- Opens file dialog to select model files +- Filters: GGUF, SafeTensors, PyTorch, GPTQ, AWQ, EXL2 +- Shows all supported formats + +**Browse Folder Button**: +- Select directories containing model files +- Useful for HuggingFace format models +- Auto-detects config.json + +**Load Model Button**: +- Initializes the selected model +- Shows loading progress +- Displays model information when complete + +#### Chat Interface + +**Chat Mode Options**: +- **Stream Output**: Shows text as it's generated +- **Chess Mode**: Enables ChessGPT for chess moves +- **Agent Mode**: Enables system command execution + +**Input Area**: +- Multi-line text input +- Supports Ctrl+Enter for sending +- Maintains conversation history + +**Output Display**: +- Shows conversation with "You:" and "Assistant:" prefixes +- Auto-scrolls to latest message +- Supports text selection and copying + +**Control Buttons**: +- **Send (Chat)**: Submit your message +- **Stop**: Interrupt generation +- **Clear Output**: Clear conversation display +- **Clear History**: Reset conversation context + +### Model Settings Tab + +#### Basic Settings + +**Context Size (n_ctx)**: +- Range: 512 to 32768 tokens +- Default: 4096 +- Higher values use more memory but allow longer conversations + +**GPU Layers**: +- Range: 0 to model layer count +- 0 = CPU only +- Higher values offload more to GPU + +**Max Tokens**: +- Maximum tokens to generate +- Range: 1 to context size +- Default: 2048 + +#### Advanced Loading Options + +**Quantization**: +- `none`: Full precision (FP16/FP32) +- `4bit`: ~75% memory savings +- `8bit`: ~50% memory savings +- `gptq`: Pre-quantized GPTQ format +- `awq`: Pre-quantized AWQ format +- `exl2`: Pre-quantized EXL2 format + +**Device Strategy**: +- `auto`: Automatic distribution +- `force_gpu`: All layers on GPU +- `balanced_split`: Split between CPU/GPU +- `cpu_only`: CPU processing only + +**GPU Memory Limit**: +- Maximum VRAM to use (in GB) +- Used with balanced_split strategy +- Prevents out-of-memory errors + +#### Sampling Parameters + +**Temperature** (0.0 - 2.0): +- Controls randomness +- 0.0 = Deterministic +- 0.7 = Balanced (default) +- 1.5+ = Very creative + +**Top-p** (0.0 - 1.0): +- Nucleus sampling threshold +- 0.9 = Default +- Lower values = More focused + +**Repetition Penalty** (1.0 - 2.0): +- Reduces repetitive text +- 1.0 = No penalty +- 1.1 = Light penalty (default) + +**Min-p** (0.0 - 1.0): +- Minimum probability threshold +- 0.0 = Disabled (default) + +**Typical-p** (0.0 - 1.0): +- Typical sampling threshold +- 1.0 = Disabled (default) + +--- + +## Model Management + +### Supported Formats + +DarkHal 2.0 supports multiple model formats: + +| Format | Extension | Use Case | Pros | Cons | +|--------|-----------|----------|------|------| +| **GGUF** | `.gguf` | CPU/GPU hybrid | Fast loading, efficient | Limited to llama.cpp models | +| **SafeTensors** | `.safetensors` | HuggingFace models | Secure, fast | Larger file sizes | +| **PyTorch** | `.bin`, `.pt`, `.pth` | Research models | Flexible | Slower loading | +| **GPTQ** | `*gptq*.safetensors` | GPU inference | 4-bit quantized | GPU required | +| **AWQ** | `*awq*.safetensors` | GPU inference | Optimized quantization | GPU required | +| **EXL2** | `.exl2` | ExLlamaV2 | Very fast | Specific hardware needs | + +### Model Library Tab + +The Model Library provides comprehensive model management: + +**Features**: +- Automatic scanning of model directories +- Metadata extraction (parameters, architecture) +- Search by name, type, or tags +- Size and modification date display +- One-click loading + +**Using the Library**: +1. Set your models directory in Settings +2. Click "Scan" to index models +3. Use search box to filter +4. Double-click to load model + +### Downloading Models + +#### HuggingFace Tab + +**Search and Browse**: +- Enter model name or organization +- Browse trending models +- Filter by task type +- View model cards + +**Download Process**: +1. Enter model ID (e.g., "meta-llama/Llama-2-7b") +2. Click "Get File List" +3. Select files to download +4. Click "Start Download" +5. Monitor progress in Downloads tab + +**File Selection**: +- Use checkboxes to select specific files +- "Select All" for complete model +- Size estimates shown for each file +- Automatic resume on failure + +#### Downloads Tab + +**Download Management**: +- Grouped display by model +- Individual file progress +- Speed and time estimates +- Pause/resume capability +- Automatic retry on failure + +**Controls**: +- Collapse/expand model groups +- Cancel individual files +- Clear completed downloads +- Set bandwidth limits + +--- + +## Agent Mode + +### ⚠️ WARNING + +Agent Mode grants the AI unrestricted system access. Only enable with trusted models and full understanding of risks. + +### Enabling Agent Mode + +1. Load any model +2. Check "🤖 Agent Mode (SYSTEM ACCESS)" +3. Confirm security warning +4. Agent mode indicator shows "ACTIVE" + +### Capabilities + +**System Control**: +- Execute shell commands +- Run PowerShell scripts +- Execute Bash commands +- Launch applications + +**File Operations**: +- Read any file +- Write/create files +- Delete files +- List directories + +**Application Control**: +- Open programs (Word, Notepad, etc.) +- Control mouse movement +- Send keyboard input +- Automate workflows + +**Programming**: +- Execute Python code +- Run scripts +- Install packages +- Compile code + +### Example Commands + +**Opening Applications**: +``` +"Open PowerShell" +"Launch Microsoft Word" +"Start notepad" +"Open calculator" +``` + +**File Operations**: +``` +"List files in current directory" +"Create a file called test.txt with 'Hello World'" +"Read the contents of config.json" +"Delete temporary files" +``` + +**System Commands**: +``` +"Show system information" +"Check disk space" +"List running processes" +"Create a new folder called Projects" +``` + +**Document Creation**: +``` +"Open Word and create a document about Python" +"Create an Apache server setup guide" +"Write a bash script to backup files" +``` + +### Safety Guidelines + +1. **Review Commands**: Always review AI-generated commands before execution +2. **Backup Data**: Keep backups before allowing file operations +3. **Limit Scope**: Use specific requests rather than broad permissions +4. **Monitor Activity**: Watch the output for unexpected behavior +5. **Disable When Done**: Turn off Agent Mode after use + +--- + +## Advanced Features + +### Chat Templates + +Chat templates format conversations for different model architectures. + +**Loading Templates**: +1. Click "Load" next to Chat Template dropdown +2. Select JSON file with templates +3. Choose template from dropdown + +**Adding Custom Templates**: +1. Click "Add" button +2. Define template format +3. Set special tokens +4. Save to templates file + +**Template Format**: +```json +{ + "name": "llama3", + "template": "<|begin_of_text|>{% for message in messages %}...", + "bos_token": "<|begin_of_text|>", + "eos_token": "<|eot_id|>" +} +``` + +### Model Conversion + +The Converter tab allows format transformation: + +**Supported Conversions**: +- GGUF → SafeTensors +- SafeTensors → GGUF +- PyTorch → GGUF +- GPTQ → GGUF + +**Conversion Process**: +1. Select source model +2. Choose target format +3. Set quantization options +4. Click "Convert" +5. Monitor progress + +### MCP Server + +Model Context Protocol enables remote access: + +**Starting Server**: +1. Go to MCP tab +2. Configure port and settings +3. Click "Start Server" +4. Note the connection URL + +**Connecting Clients**: +- Claude Desktop integration +- Remote control GUI +- Custom API clients +- Web interfaces + +**Available Endpoints**: +- `/list_models` - Available models +- `/load_model` - Load specific model +- `/generate` - Text generation +- `/chat` - Conversation mode + +### Chess Mode + +Specialized interface for chess AI: + +**Features**: +- FEN notation support +- Move generation +- Position evaluation +- Game analysis + +**Using Chess Mode**: +1. Enable "Chess Mode" checkbox +2. Enter position in FEN format +3. Request move analysis +4. Get UCI format moves + +--- + +## Troubleshooting + +### Common Issues + +#### Model Won't Load + +**Symptoms**: Error message when loading model + +**Solutions**: +- Verify file path is correct +- Check file isn't corrupted +- Ensure sufficient RAM/VRAM +- Try reducing GPU layers +- Lower context size + +#### Out of Memory + +**Symptoms**: Application crashes or freezes + +**Solutions**: +- Use quantized models (4-bit/8-bit) +- Reduce context size +- Lower GPU layers +- Use CPU-only mode +- Close other applications + +#### Slow Generation + +**Symptoms**: Very slow text generation + +**Solutions**: +- Enable GPU acceleration +- Increase GPU layers +- Use smaller models +- Reduce context size +- Check CPU/GPU usage + +#### Download Failures + +**Symptoms**: Downloads fail or hang + +**Solutions**: +- Check internet connection +- Verify HuggingFace token +- Clear download cache +- Use VPN if blocked +- Try different mirror + +### Error Messages + +**"CUDA out of memory"**: +- Reduce GPU layers +- Use smaller batch size +- Enable memory efficient attention +- Use quantized model + +**"Model file not found"**: +- Check file path +- Verify file exists +- Check permissions +- Try absolute path + +**"Invalid model format"**: +- Verify file format +- Check model compatibility +- Update DarkHal +- Try conversion + +**"Token limit exceeded"**: +- Reduce input length +- Lower max tokens +- Clear conversation history +- Increase context size + +--- + +## Keyboard Shortcuts + +### Global Shortcuts + +| Shortcut | Action | +|----------|--------| +| `Ctrl+N` | New conversation | +| `Ctrl+O` | Open model | +| `Ctrl+S` | Save conversation | +| `Ctrl+Q` | Quit application | +| `F1` | Open help | +| `F5` | Refresh model list | + +### Chat Interface + +| Shortcut | Action | +|----------|--------| +| `Ctrl+Enter` | Send message | +| `Ctrl+L` | Clear output | +| `Ctrl+H` | Clear history | +| `Esc` | Stop generation | +| `Ctrl+C` | Copy selected text | +| `Ctrl+A` | Select all | + +### Model Library + +| Shortcut | Action | +|----------|--------| +| `Ctrl+F` | Focus search | +| `Enter` | Load selected model | +| `Delete` | Remove from library | +| `F5` | Rescan directory | + +--- + +## Command Reference + +### CLI Arguments + +```bash +python main.py [options] +``` + +**Options**: +- `--gui` - Launch GUI mode (default) +- `--model PATH` - Model file path +- `--prompt TEXT` - Initial prompt +- `--stream` - Enable streaming +- `--n_ctx N` - Context size +- `--n_gpu_layers N` - GPU layers +- `--lora PATH` - LoRA adapter path + +### Configuration Files + +**settings.json**: +```json +{ + "paths": { + "models_directory": "./models", + "download_directory": "./downloads" + }, + "model_settings": { + "default_n_ctx": 4096, + "default_n_gpu_layers": 0, + "stream_by_default": true, + "temperature": 0.7, + "top_p": 0.9, + "repetition_penalty": 1.1 + } +} +``` + +**HUGGINGFACE.env**: +``` +HF_API_KEY=your_token_here +HF_HOME=./models/huggingface +``` + +--- + +## FAQ + +### General Questions + +**Q: What models work with DarkHal?** +A: Any model in GGUF, SafeTensors, PyTorch, GPTQ, AWQ, or EXL2 format. Most HuggingFace models are compatible. + +**Q: How much RAM do I need?** +A: Depends on model size. 7B models need ~8GB, 13B need ~16GB, 70B need ~64GB. Quantization reduces requirements. + +**Q: Can I run without GPU?** +A: Yes, CPU-only mode works but is slower. Use GGUF models with 0 GPU layers. + +**Q: Is my data private?** +A: Yes, all processing is local. No data is sent to external servers unless using HuggingFace downloads. + +### Model Questions + +**Q: What's the difference between formats?** +A: GGUF is optimized for CPU/GPU hybrid. SafeTensors is HuggingFace standard. GPTQ/AWQ/EXL2 are quantized for GPU. + +**Q: How do I choose quantization?** +A: 4-bit saves most memory with slight quality loss. 8-bit balances quality and size. None uses full precision. + +**Q: Why is generation slow?** +A: Check GPU usage, reduce context size, use quantized models, or enable more GPU layers. + +**Q: Can I use multiple models?** +A: One model at a time in current version. Switch models by loading different ones. + +### Agent Mode Questions + +**Q: Is Agent Mode safe?** +A: Agent Mode grants full system access. Only use with trusted models and review commands. + +**Q: What can Agent Mode do?** +A: Execute any system command, control applications, manage files, run code, automate tasks. + +**Q: How do I limit Agent Mode?** +A: Currently all-or-nothing. Future versions will have granular permissions. + +**Q: Can Agent Mode access internet?** +A: Yes, through system commands like curl or wget, and Python's requests library. + +### Troubleshooting Questions + +**Q: Download keeps failing?** +A: Check internet, verify HF token, try VPN, clear cache, or download manually. + +**Q: Model won't load?** +A: Verify path, check format, ensure enough memory, try different quantization. + +**Q: Getting CUDA errors?** +A: Update GPU drivers, check CUDA version, reduce GPU layers, or use CPU mode. + +**Q: Application crashes?** +A: Check error logs, reduce memory usage, update dependencies, file bug report. + +--- + +## Support + +### Getting Help + +**Documentation**: [https://darkhal.readthedocs.io](https://darkhal.readthedocs.io) +**GitHub Issues**: [https://github.com/darkhal/issues](https://github.com/darkhal/issues) +**Discussions**: [https://github.com/darkhal/discussions](https://github.com/darkhal/discussions) +**Email Support**: support@darkhal.ai + +### Reporting Bugs + +Include: +1. System information (OS, GPU, RAM) +2. Model details (format, size, source) +3. Error messages and logs +4. Steps to reproduce +5. Screenshots if applicable + +### Contributing + +We welcome contributions! See CONTRIBUTING.md for guidelines. + +--- + +## Appendices + +### A. Model Compatibility Matrix + +| Model Family | GGUF | SafeTensors | GPTQ | AWQ | EXL2 | +|--------------|------|-------------|------|-----|------| +| Llama 2/3 | ✅ | ✅ | ✅ | ✅ | ✅ | +| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | +| Mixtral | ✅ | ✅ | ✅ | ✅ | ❌ | +| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | +| Yi | ✅ | ✅ | ✅ | ✅ | ✅ | +| Gemma | ✅ | ✅ | ❌ | ❌ | ❌ | + +### B. Performance Benchmarks + +| Model | Format | GPU | Speed (tok/s) | Memory | +|-------|--------|-----|---------------|---------| +| Llama-7B | GGUF Q4 | RTX 3060 | 45 | 4.5GB | +| Llama-13B | GGUF Q4 | RTX 3090 | 35 | 8.5GB | +| Mistral-7B | GPTQ | RTX 4090 | 65 | 5.0GB | +| Mixtral-8x7B | AWQ | A100 | 25 | 24GB | + +### C. Glossary + +**Context Size**: Maximum tokens the model can process at once +**GPU Layers**: Model layers offloaded to GPU for acceleration +**Quantization**: Reducing model precision to save memory +**LoRA**: Low-Rank Adaptation for model fine-tuning +**Token**: Basic unit of text (roughly 0.75 words) +**VRAM**: Video RAM on graphics card +**Streaming**: Showing text as it's generated +**KV Cache**: Key-value cache for faster inference + +--- + +*DarkHal 2.0 User Manual - Version 2.0.0* +*Last Updated: January 2025* \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..934d63e --- /dev/null +++ b/README.md @@ -0,0 +1,441 @@ +If you would like to join, please email ssSnake@darkHal.org or sshhh@setecastronomy.gg. Please include a short bio and a link to 2-3 projects you have worked on. If you have no experience and would like to join, email me at ssSnake@darkHal.org and make your case, just because you dont have experience doesn't mean you dont have skills. + + +# DarkHal 2.0 🤖 + +
+ +![DarkHal 2.0 Logo](assets/logo.png) + +**Advanced AI Model Management & Training Platform** + +[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![Platform](https://img.shields.io/badge/platform-Windows%20%7C%20macOS%20%7C%20Linux-lightgrey)](https://github.com/your-org/darkhal) + +*Comprehensive solution for AI model downloading, management, training, and inference* + +
+ +## ⚠️ DISCLAIMER + +**This software is provided "as is" without any warranties or guarantees. The user assumes all responsibility for the use of this software and any consequences that may arise from its use. The developers are not liable for any damages, data loss, or other issues that may occur.** + +--- + +## 🚀 Features + +### 🔽 **Advanced Download Management** +- **Grouped Downloads**: Organize multi-file model downloads with collapsible widgets +- **HuggingFace Integration**: Direct downloading from HuggingFace Hub with authentication +- **Smart File Selection**: Choose specific files from model repositories +- **Progress Tracking**: Real-time download progress with speed monitoring +- **Resume Support**: Automatically resume interrupted downloads + +### 🤖 **AI Model Management** +- **Model Library**: Intelligent scanning and indexing of local model files +- **Multi-Format Support**: GGUF, SafeTensors, PyTorch, ONNX, and more +- **Metadata Extraction**: Automatic detection of model parameters and tags +- **Quick Search**: Fast searching by name, type, size, and tags +- **GPU Acceleration**: CUDA, ROCm, Metal, and Intel GPU support + +### 🎛️ **Inference & Chat** +- **Local Inference**: Run models locally without internet connection +- **Streaming Support**: Real-time text generation with streaming output +- **Chat Interface**: Interactive conversation mode with context memory +- **Parameter Control**: Adjustable temperature, context size, and token limits +- **LoRA Support**: Load and apply LoRA adapters to base models + +### 🌐 **Remote Control** +- **MCP Server**: Model Context Protocol server for remote management +- **Remote GUI**: Standalone application for controlling server remotely +- **Claude Integration**: Direct integration with Claude Desktop +- **API Access**: RESTful API for programmatic control + +### ⚙️ **System Integration** +- **Auto-Setup**: Automatic dependency installation and GPU detection +- **Windows Integration**: MSI installer with Windows-specific optimizations +- **Cross-Platform**: Native support for Windows, macOS, and Linux +- **Chocolatey Support**: Automated Windows dependency management + +--- + +## 📦 Installation + +### 🔧 **Quick Install (Recommended)** + +#### Windows (Coming Soon) +```bash +# Download and run the MSI installer +darkhal-2.0-setup.msi + +# Or install via Chocolatey +choco install darkhal +``` + +#### macOS/Linux (cooming soon) +```bash +# Install via pip +pip install darkhal + +# Or from source +git clone https://github.com/your-org/darkhal.git +cd darkhal +pip install -e . +``` + +### 🛠️ **Manual Installation** + +1. **Clone the repository** + ```bash + git clone https://github.com/your-org/darkhal.git + cd darkhal + ``` + +2. **Install dependencies** + ```bash + # Basic installation + pip install -r requirements.txt + + # With GPU support + pip install -r requirements.txt darkhal[gpu] + + # With audio/whisper support + pip install -r requirements.txt darkhal[audio] + ``` + +3. **Run dependency installer** (Windows) + ```bash + python windows_dependency_installer.py + ``` + +4. **Configure HuggingFace** (optional) + ```bash + # Create HUGGINGFACE.env file + echo "HF_API_KEY=your_token_here" > HUGGINGFACE.env + ``` + +--- + +## 🎯 Quick Start + +### 🖥️ **GUI Mode (Default)** +```bash +# Launch with splash screen +python main.py + +# Or use the installed command +darkhal +``` + +### 💻 **CLI Mode** +Full CLI command list coming soon. Right now the agent mode works best from a powershell, shell or bash. + + +```bash +# Interactive chat +python main.py --model path/to/model.gguf + +# Single prompt +python main.py --model path/to/model.gguf --prompt "Your question here" + +# With GPU acceleration +python main.py --model path/to/model.gguf --n_gpu_layers 32 +``` + +### 🌐 **Remote Control** +```bash +# Launch remote control GUI +python remotecontrol.py + +# Or use the installed command +darkhal-remote +``` + +### 🔌 **MCP Server** +```bash +# Start MCP server +python mcp_server.py + +# Or use the installed command +darkhal-mcp +``` + +--- + +## 📋 Usage Guide + +### 1. **First Launch** +- DarkHal 2.0 will show a splash screen with disclaimers +- Configure your models directory in Settings +- Set up HuggingFace authentication if needed +- Scan your model library for automatic indexing + +### 2. **Downloading Models** +- Go to the **HuggingFace** tab +- Search for models by name or tags +- Select desired files from multi-file models +- Downloads are organized in collapsible groups +- Monitor progress in the **Downloads** tab + +### 3. **Managing Models** +- Use the **Model Library** tab to browse local models +- Search by name, file type, or tags +- View detailed metadata and statistics +- Load models directly from the library + +### 4. **Running Inference** +- Select a model using "Browse" or the model library +- Configure context size and GPU layers +- Enter your prompt and click "Generate" +- Toggle streaming for real-time output +- Use chat mode for conversations + +### 5. **Remote Operations** +- Start the MCP server from the main application +- Launch the remote control GUI +- Connect to the server and manage models remotely +- Integrate with Claude Desktop for enhanced AI workflows + +--- + +## ⚙️ Configuration + +### 📁 **Settings Files** +- `settings.json` - Main application settings +- `HUGGINGFACE.env` - HuggingFace API credentials +- `mcp_config.json` - MCP server configuration +- `.model_index.json` - Model library index + +### 🎛️ **Key Settings** +```json +{ + "paths": { + "models_directory": "./models", + "download_directory": "./downloads" + }, + "model_settings": { + "default_n_ctx": 4096, + "default_n_gpu_layers": 0, + "stream_by_default": true + }, + "download_settings": { + "max_concurrent_downloads": 3, + "speed_limit_mbps": 0 + } +} +``` + +### 🖥️ **GPU Configuration** +DarkHal 2.0 automatically detects and optimizes for: +- **NVIDIA CUDA** (Windows/Linux) +- **AMD ROCm** (Linux) +- **Apple Metal** (macOS) +- **Intel GPU** (Windows/Linux) + +--- + +## 🔌 API Reference + +### **MCP Tools** +- `list_models` - Get available models +- `load_model` - Load a model with parameters +- `generate_text` - Generate text with prompt +- `get_system_info` - Get system capabilities + +### **Claude Integration** +```json +{ + "mcpServers": { + "darkhal": { + "command": "python", + "args": ["path/to/mcp_server.py"] + } + } +} +``` + +--- + +## 🏗️ Architecture + +### **Core Components** +``` +DarkHal 2.0/ +├── main.py # Main GUI application +├── splash_screen.py # Startup splash screen +├── remotecontrol.py # Remote control GUI +├── mcp_server.py # MCP protocol server +├── settings_manager.py # Configuration management +├── model_library.py # Model indexing & search +├── grouped_download_* # Advanced download system +└── assets/ # Icons and resources +``` + +### **Data Flow** +1. **User Interface** → Settings Manager → Model Operations +2. **Download Manager** → HuggingFace API → Local Storage +3. **Model Library** → File Scanner → Metadata Extractor +4. **MCP Server** → llama.cpp → Text Generation +5. **Remote Control** → MCP Client → Server Commands + +--- + +## 🛠️ Development + +### **Building from Source** +```bash +# Clone repository +git clone https://github.com/your-org/darkhal.git +cd darkhal + +# Install development dependencies +pip install -e .[dev] + +# Run tests +pytest tests/ + +# Build distribution +python setup.py sdist bdist_wheel + +# Build MSI installer (Windows) +python build_installer.py +``` + +### **Contributing** +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Add tests and documentation +5. Submit a pull request + +### **Code Style** +- Follow PEP 8 guidelines +- Use type hints where possible +- Add docstrings to all functions +- Include comprehensive error handling + +--- + +## 📊 Performance & Requirements + +### **System Requirements** +| Component | Minimum | Recommended | +|-----------|---------|-------------| +| **OS** | Windows 10, macOS 10.14, Ubuntu 18.04 | Latest versions | +| **Python** | 3.8+ | 3.11+ | +| **RAM** | 8 GB | 16+ GB | +| **Storage** | 10 GB free | 100+ GB for models | +| **GPU** | Optional | 8+ GB VRAM | + +### **Model Performance** +| Model Size | RAM Usage | GPU VRAM | Inference Speed | +|------------|-----------|----------|----------------| +| **7B Q4** | 4-6 GB | 4-6 GB | 10-50 tokens/sec | +| **13B Q4** | 8-10 GB | 8-10 GB | 5-25 tokens/sec | +| **70B Q4** | 40-50 GB | 40+ GB | 1-10 tokens/sec | + +--- + +## 🔍 Troubleshooting + +### **Common Issues** + +#### Installation Problems +```bash +# Missing dependencies +pip install --upgrade pip setuptools wheel +pip install -r requirements.txt + +# CUDA issues +pip install torch --index-url https://download.pytorch.org/whl/cu121 + +# Permission errors (Windows) +# Run as Administrator or use --user flag +pip install --user darkhal +``` + +#### Runtime Errors +```bash +# Model loading fails +# Check file permissions and disk space +# Verify model file integrity +# Reduce context size or GPU layers + +# Download issues +# Check internet connection +# Verify HuggingFace token +# Clear download cache +``` + +#### Performance Issues +```bash +# Slow inference +# Enable GPU acceleration +# Reduce context size +# Use quantized models +# Close other applications +``` + +### **Getting Help** +- 📖 Check the [Documentation](https://darkhal.readthedocs.io/) +- 🐛 Report [Issues](https://github.com/your-org/darkhal/issues) +- 💬 Join [Discussions](https://github.com/your-org/darkhal/discussions) +- 📧 Contact [Support](mailto:support@seteclabs.com) + +--- + +## 📄 License & Legal + +### **License** +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +### **Third-Party Components** +- **llama.cpp** - MIT License +- **HuggingFace Transformers** - Apache 2.0 +- **Tkinter** - Python Software Foundation License +- **Pillow** - PIL Software License + +### **Copyright** +© 2025 Setec Labs. All rights reserved. + +**Created by ssSnake** + +--- + +## 🚀 Roadmap + +### **Upcoming Features** +- [ ] **Multi-Model Support** - Run multiple models simultaneously +- [ ] **Fine-Tuning Interface** - Built-in model training capabilities +- [ ] **Plugin System** - Extensible architecture for custom tools +- [ ] **Cloud Integration** - AWS, Azure, GCP deployment options +- [ ] **Advanced Analytics** - Performance monitoring and analytics +- [ ] **Voice Interface** - Speech-to-text and text-to-speech +- [ ] **Docker Support** - Containerized deployment options +- [ ] **Model Marketplace** - Community model sharing platform + +### **Version History** +- **v2.0.0** - Major rewrite with advanced features +- **v1.5.0** - Remote control and MCP integration +- **v1.0.0** - Initial release with basic functionality + +--- + +## 🙏 Acknowledgments + +Special thanks to: +- **llama.cpp team** for the excellent inference engine +- **HuggingFace** for the model hub and libraries +- **Anthropic** for the MCP protocol specification +- **Open source community** for various dependencies and tools + +--- + +
+ +**⭐ Star this repository if you find it useful!** + +[**Download**](https://github.com/your-org/darkhal/releases) • [**Documentation**](https://darkhal.readthedocs.io/) • [**Issues**](https://github.com/your-org/darkhal/issues) • [**Discussions**](https://github.com/your-org/darkhal/discussions) + +
diff --git a/REMOTE_CONTROL_README.md b/REMOTE_CONTROL_README.md new file mode 100644 index 0000000..e20454e --- /dev/null +++ b/REMOTE_CONTROL_README.md @@ -0,0 +1,287 @@ +# LLM_Train Remote Control + +A standalone GUI application for remotely controlling and managing LLM_Train MCP servers. This allows you to connect to a running MCP server from a separate application and perform model management and inference operations remotely. + +## Features + +### 🔗 **Connection Management** +- Connect to local or remote MCP servers +- Real-time connection status monitoring +- Automatic reconnection handling +- Server path configuration and browsing + +### 🤖 **Model Management** +- List all available models from the server +- Remote model loading with configuration +- Context size and GPU layer settings +- Current model status display +- Model unloading capabilities + +### 💬 **Inference Interface** +- Text generation with configurable parameters +- Chat mode for conversational interactions +- Adjustable temperature and max tokens +- Real-time output display +- Chat history management + +### 📊 **System Monitoring** +- Real-time system information display +- GPU acceleration status +- Platform and architecture details +- Performance metrics + +### 📝 **Logging & Debugging** +- Comprehensive operation logging +- Error tracking and display +- Log saving functionality +- Connection event monitoring + +## Quick Start + +### Option 1: Batch File (Windows) +```bash +# Double-click launch_remote_control.bat +``` + +### Option 2: Direct Python Execution +```bash +python remotecontrol.py +``` + +## Requirements + +- **Python 3.7+** with tkinter support +- **Running MCP Server** (from main LLM_Train application) +- **Network access** to the MCP server (if remote) + +## Usage Guide + +### 1. Starting the Remote Control + +1. **Launch the application** using one of the methods above +2. The GUI will open with the connection panel at the top + +### 2. Connecting to MCP Server + +1. **Specify server path**: + - Default: `mcp_server.py` (local) + - Browse to select different server script + - Can be local file or network path + +2. **Click "Connect"** button +3. **Connection status** will show: + - 🔴 **Disconnected** (Red): Not connected + - 🟢 **Connected** (Green): Successfully connected + +### 3. Model Management + +**Models Tab:** +- **View available models**: Automatically populated on connection +- **Select model**: Click on model in the list +- **Configure parameters**: + - Context Size: 512 - 32768 tokens + - GPU Layers: 0 - 100 layers +- **Load model**: Click "Load Model" button +- **Monitor status**: Current model display shows loaded model + +### 4. Text Generation + +**Inference Tab:** +- **Enter prompt**: Type or paste text in the input area +- **Configure generation**: + - Max Tokens: 1 - 8192 + - Temperature: 0.0 - 2.0 +- **Generate**: Click "Generate" button +- **View output**: Results appear in the output area +- **Chat mode**: Enable for conversational interface + +### 5. System Information + +**System Tab:** +- **View system details**: Platform, architecture, acceleration +- **Monitor GPU status**: CUDA, ROCm, Metal availability +- **Check performance**: Current acceleration method +- **Refresh**: Update information in real-time + +### 6. Logging + +**Log Tab:** +- **Monitor operations**: All actions are logged with timestamps +- **Error tracking**: Errors are highlighted and detailed +- **Save logs**: Export logs to text file for debugging +- **Clear logs**: Reset log display + +## Advanced Configuration + +### Server Connection Options + +**Local Server:** +``` +Server Path: mcp_server.py +``` + +**Custom Path:** +``` +Server Path: C:\path\to\your\mcp_server.py +``` + +**Network Server (if supported):** +``` +Server Path: \\network\path\mcp_server.py +``` + +### Model Configuration + +**High-Performance Setup:** +- Context Size: 8192+ +- GPU Layers: Maximum supported +- Temperature: 0.1-0.3 for focused responses + +**Balanced Setup:** +- Context Size: 4096 +- GPU Layers: Auto-detected optimum +- Temperature: 0.7 for creative responses + +**CPU-Only Setup:** +- Context Size: 2048 +- GPU Layers: 0 +- Temperature: 0.5-1.0 + +### Generation Parameters + +**Creative Writing:** +- Max Tokens: 1024+ +- Temperature: 0.8-1.2 + +**Code Generation:** +- Max Tokens: 512 +- Temperature: 0.1-0.3 + +**Question Answering:** +- Max Tokens: 256 +- Temperature: 0.3-0.7 + +## Troubleshooting + +### Connection Issues + +**"Server file not found"** +- Verify the server path is correct +- Ensure the MCP server file exists +- Check file permissions + +**"Connection failed"** +- Ensure the MCP server is not already running +- Check if the server script is executable +- Verify Python dependencies are installed + +**"Disconnected unexpectedly"** +- Check server logs for errors +- Verify system resources are available +- Restart both applications + +### Model Loading Issues + +**"No models found"** +- Ensure model library is configured in main application +- Verify model files exist in specified directories +- Check library settings and scan depth + +**"Failed to load model"** +- Verify model file is not corrupted +- Check available system memory +- Reduce context size or GPU layers + +**"Out of memory"** +- Reduce context size +- Lower GPU layers +- Close other applications + +### Generation Issues + +**"No model loaded"** +- Load a model first using the Models tab +- Verify model loaded successfully +- Check current model display + +**"Generation timeout"** +- Reduce max tokens +- Simplify the prompt +- Check system resources + +**"Invalid parameters"** +- Verify temperature is between 0.0-2.0 +- Ensure max tokens is reasonable +- Check prompt is not empty + +## Technical Details + +### MCP Protocol +- Uses JSON-RPC 2.0 over stdin/stdout +- Asynchronous request/response handling +- Automatic request ID management +- Error handling and recovery + +### Threading Model +- Main UI thread for interface +- AsyncIO event loop for MCP communication +- Background threads for I/O operations +- Thread-safe callback system + +### Security Considerations +- Local process communication only +- No network ports exposed +- Input validation on all parameters +- Error sanitization in logs + +## Integration Examples + +### Automated Workflows +```python +# Example: Batch text generation +prompts = ["Explain AI", "Code a function", "Write a story"] +for prompt in prompts: + # Use remote control to generate text + # Save results to files +``` + +### API Integration +```python +# Example: Integration with other tools +remote_control = RemoteControlClient() +remote_control.connect("mcp_server.py") +result = remote_control.generate("Your prompt here") +``` + +### Monitoring Scripts +```python +# Example: System monitoring +while True: + system_info = remote_control.get_system_info() + log_performance_metrics(system_info) + time.sleep(60) +``` + +## Support and Development + +### Extending Functionality +- Add new MCP tool integrations +- Implement custom inference modes +- Create automation scripts +- Build monitoring dashboards + +### Contributing +- Follow Python coding standards +- Add comprehensive logging +- Include error handling +- Write unit tests + +### Reporting Issues +- Include full log output +- Specify system configuration +- Provide reproduction steps +- Attach relevant files + +## License + +This remote control application is part of the LLM_Train project and follows the same licensing terms as the main application. \ No newline at end of file diff --git a/__spy.py b/__spy.py new file mode 100644 index 0000000..5fdd622 --- /dev/null +++ b/__spy.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +""" +__spy.py +Lightweight global announcer for the currently loaded model. + +Usage: +- Call set_model(model_name, model_obj, **params) when a model is loaded. +- Retrieve with get_model(), get_model_name(), or get_info() anywhere. +""" + +from __future__ import annotations +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional +import threading + +@dataclass +class SpyData: + model_name: str + model: Any + params: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + d = asdict(self) + # Avoid serializing the raw model object + d["model"] = repr(self.model) + return d + +_lock = threading.RLock() +_current: Optional[SpyData] = None + +def set_model(model_name: str, model: Any, **params: Any) -> None: + """Announce the current model and its load parameters.""" + global _current + with _lock: + _current = SpyData(model_name=model_name, model=model, params=dict(params or {})) + +def get_model() -> Optional[Any]: + """Return the current model object, if any.""" + with _lock: + return _current.model if _current else None + +def get_model_name() -> Optional[str]: + """Return the current model name, if any.""" + with _lock: + return _current.model_name if _current else None + +def get_info() -> Optional[SpyData]: + """Return the full SpyData object, if any.""" + with _lock: + return _current diff --git a/agent_debug_tracer.py b/agent_debug_tracer.py new file mode 100644 index 0000000..441e579 --- /dev/null +++ b/agent_debug_tracer.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +""" +Agent Debug and Trace System for DarkAgent +Custom logging and monitoring system that tracks agent lifecycle without using sys.settrace() +""" + +import time +import threading +import queue +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional, List + +class AgentDebugTracer: + """Custom debug and trace system for DarkAgent monitoring.""" + + def __init__(self, log_file: str = "agent_debug.log"): + self.log_file = Path(log_file) + self.trace_queue = queue.Queue() + self.is_running = False + self.logger_thread = None + self.start_time = time.time() + + # Agent state tracking + self.agent_states = {} + self.event_history = [] + self.performance_metrics = { + "total_messages": 0, + "successful_responses": 0, + "errors": 0, + "average_response_time": 0.0, + "response_times": [] + } + + # Initialize log file + self._init_log_file() + + def _init_log_file(self): + """Initialize the log file with header.""" + try: + with open(self.log_file, 'w') as f: + f.write(f"=== DarkAgent Debug Trace Started: {datetime.now()} ===\n") + f.write(f"Application Launch Time: {self.start_time}\n") + f.write("=" * 60 + "\n\n") + except Exception as e: + print(f"[AGENT_DEBUG] Failed to initialize log file: {e}") + + def start_monitoring(self): + """Start the debug monitoring system.""" + if not self.is_running: + self.is_running = True + self.logger_thread = threading.Thread(target=self._logger_worker, daemon=True) + self.logger_thread.start() + self.trace("SYSTEM", "Debug monitoring started") + + def stop_monitoring(self): + """Stop the debug monitoring system.""" + if self.is_running: + self.trace("SYSTEM", "Debug monitoring stopping") + self.is_running = False + if self.logger_thread and self.logger_thread.is_alive(): + self.logger_thread.join(timeout=1.0) + + def trace(self, category: str, message: str, data: Dict[str, Any] = None): + """Add a trace entry.""" + if not self.is_running: + return + + timestamp = time.time() + elapsed = timestamp - self.start_time + + entry = { + "timestamp": timestamp, + "elapsed": elapsed, + "datetime": datetime.now().isoformat(), + "category": category, + "message": message, + "data": data or {}, + "thread": threading.current_thread().name + } + + try: + self.trace_queue.put_nowait(entry) + self.event_history.append(entry) + + # Keep history manageable + if len(self.event_history) > 1000: + self.event_history = self.event_history[-500:] + + except queue.Full: + pass # Drop trace if queue is full + + def _logger_worker(self): + """Background thread that writes trace entries to file.""" + while self.is_running: + try: + entry = self.trace_queue.get(timeout=1.0) + self._write_trace_entry(entry) + except queue.Empty: + continue + except Exception as e: + print(f"[AGENT_DEBUG] Logger error: {e}") + + def _write_trace_entry(self, entry: Dict[str, Any]): + """Write a trace entry to the log file.""" + try: + with open(self.log_file, 'a') as f: + formatted_time = f"{entry['elapsed']:8.3f}s" + thread_info = f"[{entry['thread']}]" if entry['thread'] != 'MainThread' else "" + + f.write(f"{formatted_time} [{entry['category']}]{thread_info} {entry['message']}") + + if entry['data']: + f.write(f" | Data: {json.dumps(entry['data'], indent=None)}") + + f.write("\n") + f.flush() + + except Exception as e: + print(f"[AGENT_DEBUG] Write error: {e}") + + # Agent-specific monitoring methods + def agent_startup(self, agent_name: str, config: Dict[str, Any] = None): + """Track agent startup.""" + self.agent_states[agent_name] = { + "status": "starting", + "start_time": time.time(), + "config": config or {} + } + self.trace("AGENT_STARTUP", f"Agent {agent_name} starting", {"config": config}) + + def agent_ready(self, agent_name: str): + """Track agent ready state.""" + if agent_name in self.agent_states: + self.agent_states[agent_name]["status"] = "ready" + startup_time = time.time() - self.agent_states[agent_name]["start_time"] + self.trace("AGENT_READY", f"Agent {agent_name} ready", {"startup_time": startup_time}) + + def agent_message_start(self, agent_name: str, message: str, message_id: str = None): + """Track start of message processing.""" + self.performance_metrics["total_messages"] += 1 + self.trace("AGENT_MESSAGE_START", f"Agent {agent_name} processing message", { + "message_id": message_id, + "message_preview": message[:100] + "..." if len(message) > 100 else message + }) + return time.time() # Return start time for response time calculation + + def agent_message_end(self, agent_name: str, message_id: str, start_time: float, success: bool = True, error: str = None): + """Track end of message processing.""" + response_time = time.time() - start_time + self.performance_metrics["response_times"].append(response_time) + + if success: + self.performance_metrics["successful_responses"] += 1 + self.trace("AGENT_MESSAGE_SUCCESS", f"Agent {agent_name} completed message", { + "message_id": message_id, + "response_time": response_time + }) + else: + self.performance_metrics["errors"] += 1 + self.trace("AGENT_MESSAGE_ERROR", f"Agent {agent_name} message failed", { + "message_id": message_id, + "response_time": response_time, + "error": error + }) + + # Update average response time + if self.performance_metrics["response_times"]: + self.performance_metrics["average_response_time"] = sum(self.performance_metrics["response_times"]) / len(self.performance_metrics["response_times"]) + + def agent_shutdown(self, agent_name: str): + """Track agent shutdown.""" + if agent_name in self.agent_states: + self.agent_states[agent_name]["status"] = "shutdown" + uptime = time.time() - self.agent_states[agent_name]["start_time"] + self.trace("AGENT_SHUTDOWN", f"Agent {agent_name} shutting down", {"uptime": uptime}) + + def agent_error(self, agent_name: str, error: str, context: Dict[str, Any] = None): + """Track agent errors.""" + self.performance_metrics["errors"] += 1 + self.trace("AGENT_ERROR", f"Agent {agent_name} error: {error}", context) + + def ui_event(self, event_type: str, details: Dict[str, Any] = None): + """Track UI events related to agent.""" + self.trace("UI_EVENT", event_type, details) + + def model_event(self, event_type: str, model_info: Dict[str, Any] = None): + """Track model loading/unloading events.""" + self.trace("MODEL_EVENT", event_type, model_info) + + def get_performance_summary(self) -> Dict[str, Any]: + """Get current performance metrics.""" + return { + "uptime": time.time() - self.start_time, + "agent_states": self.agent_states, + "performance": self.performance_metrics, + "recent_events": self.event_history[-10:] if self.event_history else [] + } + + def print_summary(self): + """Print performance summary to console.""" + summary = self.get_performance_summary() + print("\n=== DarkAgent Debug Summary ===") + print(f"Uptime: {summary['uptime']:.2f}s") + print(f"Total Messages: {summary['performance']['total_messages']}") + print(f"Successful Responses: {summary['performance']['successful_responses']}") + print(f"Errors: {summary['performance']['errors']}") + print(f"Average Response Time: {summary['performance']['average_response_time']:.3f}s") + print(f"Active Agents: {len([a for a in summary['agent_states'].values() if a['status'] == 'ready'])}") + print("=" * 31) + + +# Global tracer instance +_global_tracer: Optional[AgentDebugTracer] = None + +def get_tracer() -> AgentDebugTracer: + """Get the global tracer instance.""" + global _global_tracer + if _global_tracer is None: + _global_tracer = AgentDebugTracer() + _global_tracer.start_monitoring() + return _global_tracer + +def trace(category: str, message: str, data: Dict[str, Any] = None): + """Convenience function for tracing.""" + get_tracer().trace(category, message, data) + +def shutdown_tracer(): + """Shutdown the global tracer.""" + global _global_tracer + if _global_tracer: + _global_tracer.stop_monitoring() + _global_tracer = None \ No newline at end of file diff --git a/agent_dhal/__init__.py b/agent_dhal/__init__.py new file mode 100644 index 0000000..3fb86f3 --- /dev/null +++ b/agent_dhal/__init__.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +""" +AgentDhal - Complete AI Agent Framework for DarkHal 2.0 + +A comprehensive agent framework providing: +- Multi-agent conversation capabilities +- Agent orchestration and team management +- Tool integration and function calling +- Model context management +- Memory and state management +- Customizable agent behaviors + +Legal Attribution: +This software is based on Microsoft AutoGen (https://github.com/microsoft/autogen) +Licensed under MIT License. AgentDhal is a derivative work with +modifications and extensions for the DarkHal project. + +Copyright (c) 2025 DarkHal Project +""" + +__version__ = "1.0.0" +__author__ = "DarkHal Project (based on Microsoft AutoGen)" + +# Import core AgentDhal components +from .agentdhal_core import ( + Agent, + AgentId, + AgentRuntime, + SingleThreadedAgentRuntime, + RoutedAgent, + MessageContext, + DefaultTopicId, + message_handler, + default_subscription, + BaseAgent, + AgentType, + TopicId, + Subscription +) + +# Import Dhal - our primary AI agent +from .hal import Dhal, DhalConfig, create_dhal + +# Import other AgentDhal components (available but not primary focus) +try: + from .agentdhal_agentchat import ( + AssistantAgent, + UserProxyAgent, + ChatAgent, + Team + ) +except ImportError: + # Graceful fallback if agentchat modules have issues + AssistantAgent = None + UserProxyAgent = None + ChatAgent = None + Team = None + +__all__ = [ + # Core framework + "Agent", + "AgentId", + "AgentRuntime", + "SingleThreadedAgentRuntime", + "RoutedAgent", + "MessageContext", + "DefaultTopicId", + "message_handler", + "default_subscription", + "BaseAgent", + "AgentType", + "TopicId", + "Subscription", + + # Primary Hal Agent + "Hal", + "HalConfig", + "create_hal", + + # Additional Agent Components (if available) + "AssistantAgent", + "UserProxyAgent", + "ChatAgent", + "Team" +] \ No newline at end of file diff --git a/agent_dhal/agentdhal_agentchat/__init__.py b/agent_dhal/agentdhal_agentchat/__init__.py new file mode 100644 index 0000000..2ee8beb --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/__init__.py @@ -0,0 +1,14 @@ +""" +This module provides the main entry point for the agentdhal_agentchat package. +It includes logger names for trace and event logs, and retrieves the package version. +""" + +import importlib.metadata + +TRACE_LOGGER_NAME = "agentdhal_agentchat" +"""Logger name for trace logs.""" + +EVENT_LOGGER_NAME = "agentdhal_agentchat.events" +"""Logger name for event logs.""" + +__version__ = importlib.metadata.version("agentdhal_agentchat") diff --git a/agent_dhal/agentdhal_agentchat/agents/__init__.py b/agent_dhal/agentdhal_agentchat/agents/__init__.py new file mode 100644 index 0000000..ebce7b8 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/__init__.py @@ -0,0 +1,25 @@ +""" +This module initializes various pre-defined agents provided by the package. +BaseChatAgent is the base class for all agents in AgentChat. +""" + +from ._assistant_agent import AssistantAgent +from ._base_chat_agent import BaseChatAgent +from ._code_executor_agent import ApprovalFuncType, ApprovalRequest, ApprovalResponse, CodeExecutorAgent +from ._message_filter_agent import MessageFilterAgent, MessageFilterConfig, PerSourceFilter +from ._society_of_mind_agent import SocietyOfMindAgent +from ._user_proxy_agent import UserProxyAgent + +__all__ = [ + "BaseChatAgent", + "AssistantAgent", + "CodeExecutorAgent", + "SocietyOfMindAgent", + "UserProxyAgent", + "MessageFilterAgent", + "MessageFilterConfig", + "PerSourceFilter", + "ApprovalRequest", + "ApprovalResponse", + "ApprovalFuncType", +] diff --git a/agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py b/agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py new file mode 100644 index 0000000..cc10d69 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py @@ -0,0 +1,1699 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +import warnings +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall +from agentdhal_core.memory import Memory +from agentdhal_core.model_context import ( + ChatCompletionContext, + UnboundedChatCompletionContext, +) +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + FunctionExecutionResult, + FunctionExecutionResultMessage, + LLMMessage, + SystemMessage, +) +from agentdhal_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench +from pydantic import BaseModel, Field +from typing_extensions import Self + +from .. import EVENT_LOGGER_NAME +from ..base import Handoff as HandoffBase +from ..base import Response +from ..messages import ( + BaseAgentEvent, + BaseChatMessage, + HandoffMessage, + MemoryQueryEvent, + ModelClientStreamingChunkEvent, + StructuredMessage, + StructuredMessageFactory, + TextMessage, + ThoughtEvent, + ToolCallExecutionEvent, + ToolCallRequestEvent, + ToolCallSummaryMessage, +) +from ..state import AssistantAgentState +from ..utils import remove_images +from ._base_chat_agent import BaseChatAgent + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + +# Add type variables for more specific typing +T = TypeVar("T", bound=BaseModel) +R = TypeVar("R", bound=BaseModel) + + +class HalConfig(BaseModel): + """The declarative configuration for the assistant agent.""" + + name: str + model_client: ComponentModel + tools: List[ComponentModel] | None = None + workbench: List[ComponentModel] | None = None + handoffs: List[HandoffBase | str] | None = None + model_context: ComponentModel | None = None + memory: List[ComponentModel] | None = None + description: str + system_message: str | None = None + model_client_stream: bool = False + reflect_on_tool_use: bool + tool_call_summary_format: str + max_tool_iterations: int = Field(default=1, ge=1) + metadata: Dict[str, str] | None = None + structured_message_factory: ComponentModel | None = None + + +class Hal(BaseChatAgent, Component[AssistantAgentConfig]): + """An agent that provides assistance with tool use. + The :meth:`on_messages` returns a :class:`~agentdhal_agentchat.base.Response` + in which :attr:`~agentdhal_agentchat.base.Response.chat_message` is the final + response message. + + The :meth:`on_messages_stream` creates an async generator that produces + the inner messages as they are created, and the :class:`~agentdhal_agentchat.base.Response` + object as the last item before closing the generator. + + The :meth:`BaseChatAgent.run` method returns a :class:`~agentdhal_agentchat.base.TaskResult` + containing the messages produced by the agent. In the list of messages, + :attr:`~agentdhal_agentchat.base.TaskResult.messages`, + the last message is the final response message. + + The :meth:`BaseChatAgent.run_stream` method creates an async generator that produces + the inner messages as they are created, and the :class:`~agentdhal_agentchat.base.TaskResult` + object as the last item before closing the generator. + + .. attention:: + + The caller must only pass the new messages to the agent on each call + to the :meth:`on_messages`, :meth:`on_messages_stream`, :meth:`BaseChatAgent.run`, + or :meth:`BaseChatAgent.run_stream` methods. + The agent maintains its state between calls to these methods. + Do not pass the entire conversation history to the agent on each call. + + .. warning:: + The assistant agent is not thread-safe or coroutine-safe. + It should not be shared between multiple tasks or coroutines, and it should + not call its methods concurrently. + + The following diagram shows how the assistant agent works: + + .. image:: ../../images/assistant-agent.svg + + **Structured output:** + + If the `output_content_type` is set, the agent will respond with a :class:`~agentdhal_agentchat.messages.StructuredMessage` + instead of a :class:`~agentdhal_agentchat.messages.TextMessage` in the final response by default. + + .. note:: + + Currently, setting `output_content_type` prevents the agent from being + able to call `load_component` and `dum_component` methods for serializable + configuration. This will be fixed soon in the future. + + **Tool call behavior:** + + * If the model returns no tool call, then the response is immediately returned as a :class:`~agentdhal_agentchat.messages.TextMessage` or a :class:`~agentdhal_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~agentdhal_agentchat.base.Response.chat_message`. This ends the tool call iteration loop regardless of the `max_tool_iterations` setting. + * When the model returns tool calls, they will be executed right away: + - When `reflect_on_tool_use` is False, the tool call results are returned as a :class:`~agentdhal_agentchat.messages.ToolCallSummaryMessage` in :attr:`~agentdhal_agentchat.base.Response.chat_message`. You can customise the summary with either a static format string (`tool_call_summary_format`) **or** a callable (`tool_call_summary_formatter`); the callable is evaluated once per tool call. + - When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and final response is returned as a :class:`~agentdhal_agentchat.messages.TextMessage` or a :class:`~agentdhal_agentchat.messages.StructuredMessage` (when using structured output) in :attr:`~agentdhal_agentchat.base.Response.chat_message`. + - `reflect_on_tool_use` is set to `True` by default when `output_content_type` is set. + - `reflect_on_tool_use` is set to `False` by default when `output_content_type` is not set. + * If the model returns multiple tool calls, they will be executed concurrently. To disable parallel tool calls you need to configure the model client. For example, set `parallel_tool_calls=False` for :class:`~agentdhal_extensions.models.openai.OpenAIChatCompletionClient` and :class:`~agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient`. + * The `max_tool_iterations` parameter controls how many sequential tool call iterations the agent can perform in a single run. When set to 1 (default), the agent executes tool calls once and returns the result. When set higher, the agent can make additional model calls to execute more tool calls if the model continues to request them, enabling multi-step tool-based workflows. The agent stops when either the model returns a text response (instead of tool calls) or the maximum number of iterations is reached. + + .. tip:: + + By default, the tool call results are returned as the response when tool + calls are made, so pay close attention to how the tools' return values + are formatted—especially if another agent expects a specific schema. + + * Use **`tool_call_summary_format`** for a simple static template. + * Use **`tool_call_summary_formatter`** for full programmatic control + (e.g., "hide large success payloads, show full details on error"). + + *Note*: `tool_call_summary_formatter` is **not serializable** and will + be ignored when an agent is loaded from, or exported to, YAML/JSON + configuration files. + + + **Hand off behavior:** + + * If a handoff is triggered, a :class:`~agentdhal_agentchat.messages.HandoffMessage` will be returned in :attr:`~agentdhal_agentchat.base.Response.chat_message`. + * If there are tool calls, they will also be executed right away before returning the handoff. + * The tool calls and results are passed to the target agent through :attr:`~agentdhal_agentchat.messages.HandoffMessage.context`. + + + .. note:: + If multiple handoffs are detected, only the first handoff is executed. + To avoid this, disable parallel tool calls in the model client configuration. + + + **Limit context size sent to the model:** + + You can limit the number of messages sent to the model by setting + the `model_context` parameter to a :class:`~agentdhal_core.model_context.BufferedChatCompletionContext`. + This will limit the number of recent messages sent to the model and can be useful + when the model has a limit on the number of tokens it can process. + Another option is to use a :class:`~agentdhal_core.model_context.TokenLimitedChatCompletionContext` + which will limit the number of tokens sent to the model. + You can also create your own model context by subclassing + :class:`~agentdhal_core.model_context.ChatCompletionContext`. + + **Streaming mode:** + + The assistant agent can be used in streaming mode by setting `model_client_stream=True`. + In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield + :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent` + messages as the model client produces chunks of response. + The chunk messages will not be included in the final response's inner messages. + + Args: + name (str): The name of the agent. + model_client (ChatCompletionClient): The model client to use for inference. + tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent. + workbench (Workbench | Sequence[Workbench] | None, optional): The workbench or list of workbenches to use for the agent. + Tools cannot be used when workbench is set and vice versa. + handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent, + allowing it to transfer to other agents by responding with a :class:`HandoffMessage`. + The transfer is only executed when the team is in :class:`~agentdhal_agentchat.teams.Swarm`. + If a handoff is a string, it should represent the target agent's name. + model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset. + description (str, optional): The description of the agent. + system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable. + model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode. + :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent` + messages as the model client produces chunks of response. Defaults to `False`. + reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result + to generate a response. If `False`, the tool call result will be returned as the response. By default, if `output_content_type` is set, this will be `True`; + if `output_content_type` is not set, this will be `False`. + output_content_type (type[BaseModel] | None, optional): The output content type for :class:`~agentdhal_agentchat.messages.StructuredMessage` response as a Pydantic model. + This will be used with the model client to generate structured output. + If this is set, the agent will respond with a :class:`~agentdhal_agentchat.messages.StructuredMessage` instead of a :class:`~agentdhal_agentchat.messages.TextMessage` + in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made. + output_content_type_format (str | None, optional): (Experimental) The format string used for the content of a :class:`~agentdhal_agentchat.messages.StructuredMessage` response. + max_tool_iterations (int, optional): The maximum number of tool iterations to perform until the model stops making tool calls. Defaults to `1`, which means the agent will + only execute the tool calls made by the model once, and return the result as a :class:`~agentdhal_agentchat.messages.ToolCallSummaryMessage`, + or a :class:`~agentdhal_agentchat.messages.TextMessage` or a :class:`~agentdhal_agentchat.messages.StructuredMessage` (when using structured output) + in :attr:`~agentdhal_agentchat.base.Response.chat_message` as the final response. + As soon as the model stops making tool calls, the agent will stop executing tool calls and return the result as the final response. + The value must be greater than or equal to 1. + tool_call_summary_format (str, optional): Static format string applied to each tool call result when composing the :class:`~agentdhal_agentchat.messages.ToolCallSummaryMessage`. + Defaults to ``"{result}"``. Ignored if `tool_call_summary_formatter` is provided. When `reflect_on_tool_use` is ``False``, the summaries for all tool + calls are concatenated with a newline ('\\n') and returned as the response. Placeholders available in the template: + `{tool_name}`, `{arguments}`, `{result}`, `{is_error}`. + tool_call_summary_formatter (Callable[[FunctionCall, FunctionExecutionResult], str] | None, optional): + Callable that receives the ``FunctionCall`` and its ``FunctionExecutionResult`` and returns the summary string. + Overrides `tool_call_summary_format` when supplied and allows conditional logic — for example, emitting static string like + ``"Tool FooBar executed successfully."`` on success and a full payload (including all passed arguments etc.) only on failure. + + **Limitation**: The callable is *not serializable*; values provided via YAML/JSON configs are ignored. + + .. note:: + + `tool_call_summary_formatter` is intended for in-code use only. It cannot currently be saved or restored via + configuration files. + + memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`. + metadata (Dict[str, str] | None, optional): Optional metadata for tracking. + + Raises: + ValueError: If tool names are not unique. + ValueError: If handoff names are not unique. + ValueError: If handoff names are not unique from tool names. + ValueError: If maximum number of tool iterations is less than 1. + + Examples: + + **Example 1: basic agent** + + The following example demonstrates how to create an assistant agent with + a model client and generate a response to a simple task. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + result = await agent.run(task="Name two cities in North America.") + print(result) + + + asyncio.run(main()) + + **Example 2: model client token streaming** + + This example demonstrates how to create an assistant agent with + a model client and generate a token stream by setting `model_client_stream=True`. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent( + name="assistant", + model_client=model_client, + model_client_stream=True, + ) + + stream = agent.run_stream(task="Name two cities in North America.") + async for message in stream: + print(message) + + + asyncio.run(main()) + + .. code-block:: text + + source='user' models_usage=None metadata={} content='Name two cities in North America.' type='TextMessage' + source='assistant' models_usage=None metadata={} content='Two' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' cities' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' in' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' North' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' America' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' are' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' New' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' York' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' City' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' and' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' Toronto' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content='.' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content=' TERMIN' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=None metadata={} content='ATE' type='ModelClientStreamingChunkEvent' + source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) metadata={} content='Two cities in North America are New York City and Toronto. TERMINATE' type='TextMessage' + messages=[TextMessage(source='user', models_usage=None, metadata={}, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content='Two cities in North America are New York City and Toronto. TERMINATE', type='TextMessage')] stop_reason=None + + + **Example 3: agent with tools** + + The following example demonstrates how to create an assistant agent with + a model client and a tool, generate a stream of messages for a task, and + print the messages to the console using :class:`~agentdhal_agentchat.ui.Console`. + + The tool is a simple function that returns the current time. + Under the hood, the function is wrapped in a :class:`~agentdhal_core.tools.FunctionTool` + and used with the agent's model client. The doc string of the function + is used as the tool description, the function name is used as the tool name, + and the function signature including the type hints is used as the tool arguments. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) + await Console(agent.run_stream(task="What is the current time?")) + + + asyncio.run(main()) + + **Example 4: agent with max_tool_iterations** + + The following example demonstrates how to use the `max_tool_iterations` parameter + to control how many times the agent can execute tool calls in a single run. + This is useful when you want the agent to perform multiple sequential tool + operations to reach a goal. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + + + # Global counter state + counter = 0 + + + def increment_counter() -> str: + \"\"\"Increment the counter by 1 and return the current value.\"\"\" + global counter + counter += 1 + return f"Counter incremented to: {counter}" + + + def get_counter() -> str: + \"\"\"Get the current counter value.\"\"\" + global counter + return f"Current counter value: {counter}" + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + + # Create agent with max_tool_iterations=5 to allow multiple tool calls + agent = AssistantAgent( + name="assistant", + model_client=model_client, + tools=[increment_counter, get_counter], + max_tool_iterations=5, # Allow up to 5 tool call iterations + reflect_on_tool_use=True, # Get a final summary after tool calls + ) + + await Console(agent.run_stream(task="Increment the counter 3 times and then tell me the final value.")) + + + asyncio.run(main()) + + **Example 5: agent with Model-Context Protocol (MCP) workbench** + + The following example demonstrates how to create an assistant agent with + a model client and an :class:`~agentdhal_extensions.tools.mcp.McpWorkbench` for + interacting with a Model-Context Protocol (MCP) server. + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import StdioServerParams, McpWorkbench + + + async def main() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + # You can also use `start()` and `stop()` to manage the session. + async with McpWorkbench(server_params=params) as workbench: + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + assistant = AssistantAgent( + name="Assistant", + model_client=model_client, + workbench=workbench, + reflect_on_tool_use=True, + ) + await Console( + assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see.") + ) + + + asyncio.run(main()) + + **Example 6: agent with structured output and tool** + + The following example demonstrates how to create an assistant agent with + a model client configured to use structured output and a tool. + Note that you need to use :class:`~agentdhal_core.tools.FunctionTool` to create the tool + and the `strict=True` is required for structured output mode. + Because the model is configured to use structured output, the output + reflection response will be a JSON formatted string. + + .. code-block:: python + + import asyncio + from typing import Literal + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core.tools import FunctionTool + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from pydantic import BaseModel + + + # Define the structured output format. + class AgentResponse(BaseModel): + thoughts: str + response: Literal["happy", "sad", "neutral"] + + + # Define the function to be called as a tool. + def sentiment_analysis(text: str) -> str: + \"\"\"Given a text, return the sentiment.\"\"\" + return "happy" if "happy" in text else "sad" if "sad" in text else "neutral" + + + # Create a FunctionTool instance with `strict=True`, + # which is required for structured output mode. + tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True) + + # Create an OpenAIChatCompletionClient instance that supports structured output. + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + ) + + # Create an AssistantAgent instance that uses the tool and model client. + agent = AssistantAgent( + name="assistant", + model_client=model_client, + tools=[tool], + system_message="Use the tool to analyze sentiment.", + output_content_type=AgentResponse, + ) + + + async def main() -> None: + stream = agent.run_stream(task="I am happy today!") + await Console(stream) + + + asyncio.run(main()) + + .. code-block:: text + + ---------- assistant ---------- + [FunctionCall(id='call_tIZjAVyKEDuijbBwLY6RHV2p', arguments='{"text":"I am happy today!"}', name='sentiment_analysis')] + ---------- assistant ---------- + [FunctionExecutionResult(content='happy', call_id='call_tIZjAVyKEDuijbBwLY6RHV2p', is_error=False)] + ---------- assistant ---------- + {"thoughts":"The user expresses a clear positive emotion by stating they are happy today, suggesting an upbeat mood.","response":"happy"} + + **Example 7: agent with bounded model context** + + The following example shows how to use a + :class:`~agentdhal_core.model_context.BufferedChatCompletionContext` + that only keeps the last 2 messages (1 user + 1 assistant). + Bounded model context is useful when the model has a limit on the + number of tokens it can process. + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_core.model_context import BufferedChatCompletionContext + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + # Create a model client. + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + # api_key = "your_openai_api_key" + ) + + # Create a model context that only keeps the last 2 messages (1 user + 1 assistant). + model_context = BufferedChatCompletionContext(buffer_size=2) + + # Create an AssistantAgent instance with the model client and context. + agent = AssistantAgent( + name="assistant", + model_client=model_client, + model_context=model_context, + system_message="You are a helpful assistant.", + ) + + result = await agent.run(task="Name two cities in North America.") + print(result.messages[-1].content) # type: ignore + + result = await agent.run(task="My favorite color is blue.") + print(result.messages[-1].content) # type: ignore + + result = await agent.run(task="Did I ask you any question?") + print(result.messages[-1].content) # type: ignore + + + asyncio.run(main()) + + .. code-block:: text + + Two cities in North America are New York City and Toronto. + That's great! Blue is often associated with calmness and serenity. Do you have a specific shade of blue that you like, or any particular reason why it's your favorite? + No, you didn't ask a question. I apologize for any misunderstanding. If you have something specific you'd like to discuss or ask, feel free to let me know! + + **Example 8: agent with memory** + + The following example shows how to use a list-based memory with the assistant agent. + The memory is preloaded with some initial content. + Under the hood, the memory is used to update the model context + before making an inference, using the :meth:`~agentdhal_core.memory.Memory.update_context` method. + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_core.memory import ListMemory, MemoryContent + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + # Create a model client. + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + # api_key = "your_openai_api_key" + ) + + # Create a list-based memory with some initial content. + memory = ListMemory() + await memory.add(MemoryContent(content="User likes pizza.", mime_type="text/plain")) + await memory.add(MemoryContent(content="User dislikes cheese.", mime_type="text/plain")) + + # Create an AssistantAgent instance with the model client and memory. + agent = AssistantAgent( + name="assistant", + model_client=model_client, + memory=[memory], + system_message="You are a helpful assistant.", + ) + + result = await agent.run(task="What is a good dinner idea?") + print(result.messages[-1].content) # type: ignore + + + asyncio.run(main()) + + .. code-block:: text + + How about making a delicious pizza without cheese? You can create a flavorful veggie pizza with a variety of toppings. Here's a quick idea: + + **Veggie Tomato Sauce Pizza** + - Start with a pizza crust (store-bought or homemade). + - Spread a layer of marinara or tomato sauce evenly over the crust. + - Top with your favorite vegetables like bell peppers, mushrooms, onions, olives, and spinach. + - Add some protein if you'd like, such as grilled chicken or pepperoni (ensure it's cheese-free). + - Sprinkle with herbs like oregano and basil, and maybe a drizzle of olive oil. + - Bake according to the crust instructions until the edges are golden and the veggies are cooked. + + Serve it with a side salad or some garlic bread to complete the meal! Enjoy your dinner! + + **Example 9: agent with `o1-mini`** + + The following example shows how to use `o1-mini` model with the assistant agent. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="o1-mini", + # api_key = "your_openai_api_key" + ) + # The system message is not supported by the o1 series model. + agent = AssistantAgent(name="assistant", model_client=model_client, system_message=None) + + result = await agent.run(task="What is the capital of France?") + print(result.messages[-1].content) # type: ignore + + + asyncio.run(main()) + + .. note:: + + The `o1-preview` and `o1-mini` models do not support system message and function calling. + So the `system_message` should be set to `None` and the `tools` and `handoffs` should not be set. + See `o1 beta limitations `_ for more details. + + + **Example 10: agent using reasoning model with custom model context.** + + The following example shows how to use a reasoning model (DeepSeek R1) with the assistant agent. + The model context is used to filter out the thought field from the assistant message. + + .. code-block:: python + + import asyncio + from typing import List + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_core.model_context import UnboundedChatCompletionContext + from agentdhal_core.models import AssistantMessage, LLMMessage, ModelFamily + from agentdhal_extensions.models.ollama import OllamaChatCompletionClient + + + class ReasoningModelContext(UnboundedChatCompletionContext): + \"\"\"A model context for reasoning models.\"\"\" + + async def get_messages(self) -> List[LLMMessage]: + messages = await super().get_messages() + # Filter out thought field from AssistantMessage. + messages_out: List[LLMMessage] = [] + for message in messages: + if isinstance(message, AssistantMessage): + message.thought = None + messages_out.append(message) + return messages_out + + + # Create an instance of the model client for DeepSeek R1 hosted locally on Ollama. + model_client = OllamaChatCompletionClient( + model="deepseek-r1:8b", + model_info={ + "vision": False, + "function_calling": False, + "json_output": False, + "family": ModelFamily.R1, + "structured_output": True, + }, + ) + + agent = AssistantAgent( + "reasoning_agent", + model_client=model_client, + model_context=ReasoningModelContext(), # Use the custom model context. + ) + + + async def run_reasoning_agent() -> None: + result = await agent.run(task="What is the capital of France?") + print(result) + + + asyncio.run(run_reasoning_agent()) + + For detailed examples and usage, see the Examples section below. + """ + + component_version = 2 + component_config_schema = AssistantAgentConfig + component_provider_override = "agentdhal_agentchat.agents.AssistantAgent" + + def __init__( + self, + name: str, + model_client: ChatCompletionClient, + *, + tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, + workbench: Workbench | Sequence[Workbench] | None = None, + handoffs: List[HandoffBase | str] | None = None, + model_context: ChatCompletionContext | None = None, + description: str = "An agent that provides assistance with ability to use tools.", + system_message: ( + str | None + ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", + model_client_stream: bool = False, + reflect_on_tool_use: bool | None = None, + max_tool_iterations: int = 1, + tool_call_summary_format: str = "{result}", + tool_call_summary_formatter: Callable[[FunctionCall, FunctionExecutionResult], str] | None = None, + output_content_type: type[BaseModel] | None = None, + output_content_type_format: str | None = None, + memory: Sequence[Memory] | None = None, + metadata: Dict[str, str] | None = None, + ): + super().__init__(name=name, description=description) + self._metadata = metadata or {} + self._model_client = model_client + self._model_client_stream = model_client_stream + self._output_content_type: type[BaseModel] | None = output_content_type + self._output_content_type_format = output_content_type_format + self._structured_message_factory: StructuredMessageFactory | None = None + if output_content_type is not None: + self._structured_message_factory = StructuredMessageFactory( + input_model=output_content_type, format_string=output_content_type_format + ) + + self._memory = None + if memory is not None: + if isinstance(memory, list): + self._memory = memory + else: + raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") + + self._system_messages: List[SystemMessage] = [] + if system_message is None: + self._system_messages = [] + else: + self._system_messages = [SystemMessage(content=system_message)] + self._tools: List[BaseTool[Any, Any]] = [] + if tools is not None: + if model_client.model_info["function_calling"] is False: + raise ValueError("The model does not support function calling.") + for tool in tools: + if isinstance(tool, BaseTool): + self._tools.append(tool) + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + self._tools.append(FunctionTool(tool, description=description)) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + tool_names = [tool.name for tool in self._tools] + if len(tool_names) != len(set(tool_names)): + raise ValueError(f"Tool names must be unique: {tool_names}") + + # Handoff tools. + self._handoff_tools: List[BaseTool[Any, Any]] = [] + self._handoffs: Dict[str, HandoffBase] = {} + if handoffs is not None: + if model_client.model_info["function_calling"] is False: + raise ValueError("The model does not support function calling, which is needed for handoffs.") + for handoff in handoffs: + if isinstance(handoff, str): + handoff = HandoffBase(target=handoff) + if isinstance(handoff, HandoffBase): + self._handoff_tools.append(handoff.handoff_tool) + self._handoffs[handoff.name] = handoff + else: + raise ValueError(f"Unsupported handoff type: {type(handoff)}") + # Check if handoff tool names are unique. + handoff_tool_names = [tool.name for tool in self._handoff_tools] + if len(handoff_tool_names) != len(set(handoff_tool_names)): + raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + # Create sets for faster lookup + tool_names_set = set(tool_names) + handoff_tool_names_set = set(handoff_tool_names) + + # Check if there's any overlap between handoff tool names and tool names + overlap = tool_names_set.intersection(handoff_tool_names_set) + + # Also check if any handoff target name matches a tool name + # This handles the case where a handoff is specified directly with a string that matches a tool name + for handoff in handoffs or []: + if isinstance(handoff, str) and handoff in tool_names_set: + raise ValueError("Handoff names must be unique from tool names") + elif isinstance(handoff, HandoffBase) and handoff.target in tool_names_set: + raise ValueError("Handoff names must be unique from tool names") + + if overlap: + raise ValueError("Handoff names must be unique from tool names") + + if workbench is not None: + if self._tools: + raise ValueError("Tools cannot be used with a workbench.") + if isinstance(workbench, Sequence): + self._workbench = workbench + else: + self._workbench = [workbench] + else: + self._workbench = [StaticStreamWorkbench(self._tools)] + + if model_context is not None: + self._model_context = model_context + else: + self._model_context = UnboundedChatCompletionContext() + + if self._output_content_type is not None and reflect_on_tool_use is None: + # If output_content_type is set, we need to reflect on tool use by default. + self._reflect_on_tool_use = True + elif reflect_on_tool_use is None: + self._reflect_on_tool_use = False + else: + self._reflect_on_tool_use = reflect_on_tool_use + + # Tool call loop + self._max_tool_iterations = max_tool_iterations + if self._max_tool_iterations < 1: + raise ValueError( + f"Maximum number of tool iterations must be greater than or equal to 1, got {max_tool_iterations}" + ) + + self._tool_call_summary_format = tool_call_summary_format + self._tool_call_summary_formatter = tool_call_summary_formatter + self._is_running = False + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """Get the types of messages this agent can produce. + + Returns: + Sequence of message types this agent can generate + """ + types: List[type[BaseChatMessage]] = [TextMessage, ToolCallSummaryMessage, HandoffMessage] + if self._structured_message_factory is not None: + types.append(StructuredMessage) + return types + + @property + def model_context(self) -> ChatCompletionContext: + """Get the model context used by this agent. + + Returns: + The chat completion context for this agent + """ + return self._model_context + + async def on_messages( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: CancellationToken, + ) -> Response: + """Process incoming messages and generate a response. + + Args: + messages: Sequence of messages to process + cancellation_token: Token for cancelling operation + + Returns: + Response containing the agent's reply + """ + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: CancellationToken, + ) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]: + """Process messages and stream the response. + + Args: + messages: Sequence of messages to process + cancellation_token: Token for cancelling operation + + Yields: + Events, messages and final response during processing + """ + + # Gather all relevant state here + agent_name = self.name + model_context = self._model_context + memory = self._memory + system_messages = self._system_messages + workbench = self._workbench + handoff_tools = self._handoff_tools + handoffs = self._handoffs + model_client = self._model_client + model_client_stream = self._model_client_stream + reflect_on_tool_use = self._reflect_on_tool_use + max_tool_iterations = self._max_tool_iterations + tool_call_summary_format = self._tool_call_summary_format + tool_call_summary_formatter = self._tool_call_summary_formatter + output_content_type = self._output_content_type + + # STEP 1: Add new user/handoff messages to the model context + await self._add_messages_to_context( + model_context=model_context, + messages=messages, + ) + + # STEP 2: Update model context with any relevant memory + inner_messages: List[BaseAgentEvent | BaseChatMessage] = [] + for event_msg in await self._update_model_context_with_memory( + memory=memory, + model_context=model_context, + agent_name=agent_name, + ): + inner_messages.append(event_msg) + yield event_msg + + # STEP 3: Generate a message ID for correlation between streaming chunks and final message + message_id = str(uuid.uuid4()) + + # STEP 4: Run the first inference + model_result = None + async for inference_output in self._call_llm( + model_client=model_client, + model_client_stream=model_client_stream, + system_messages=system_messages, + model_context=model_context, + workbench=workbench, + handoff_tools=handoff_tools, + agent_name=agent_name, + cancellation_token=cancellation_token, + output_content_type=output_content_type, + message_id=message_id, + ): + if isinstance(inference_output, CreateResult): + model_result = inference_output + else: + # Streaming chunk event + yield inference_output + + assert model_result is not None, "No model result was produced." + + # --- NEW: If the model produced a hidden "thought," yield it as an event --- + if model_result.thought: + thought_event = ThoughtEvent(content=model_result.thought, source=agent_name) + yield thought_event + inner_messages.append(thought_event) + + # Add the assistant message to the model context (including thought if present) + await model_context.add_message( + AssistantMessage( + content=model_result.content, + source=agent_name, + thought=getattr(model_result, "thought", None), + ) + ) + + # STEP 5: Process the model output + async for output_event in self._process_model_result( + model_result=model_result, + inner_messages=inner_messages, + cancellation_token=cancellation_token, + agent_name=agent_name, + system_messages=system_messages, + model_context=model_context, + workbench=workbench, + handoff_tools=handoff_tools, + handoffs=handoffs, + model_client=model_client, + model_client_stream=model_client_stream, + reflect_on_tool_use=reflect_on_tool_use, + max_tool_iterations=max_tool_iterations, + tool_call_summary_format=tool_call_summary_format, + tool_call_summary_formatter=tool_call_summary_formatter, + output_content_type=output_content_type, + message_id=message_id, + format_string=self._output_content_type_format, + ): + yield output_event + + @staticmethod + async def _add_messages_to_context( + model_context: ChatCompletionContext, + messages: Sequence[BaseChatMessage], + ) -> None: + """ + Add incoming messages to the model context. + """ + for msg in messages: + if isinstance(msg, HandoffMessage): + for llm_msg in msg.context: + await model_context.add_message(llm_msg) + await model_context.add_message(msg.to_model_message()) + + @staticmethod + async def _update_model_context_with_memory( + memory: Optional[Sequence[Memory]], + model_context: ChatCompletionContext, + agent_name: str, + ) -> List[MemoryQueryEvent]: + """Update model context with memory content. + + Args: + memory: Optional sequence of memory stores to query + model_context: Context to update with memory content + agent_name: Name of the agent for event tracking + + Returns: + List of memory query events generated during update + """ + events: List[MemoryQueryEvent] = [] + if memory: + for mem in memory: + update_context_result = await mem.update_context(model_context) + if update_context_result and len(update_context_result.memories.results) > 0: + memory_query_event_msg = MemoryQueryEvent( + content=update_context_result.memories.results, + source=agent_name, + ) + events.append(memory_query_event_msg) + return events + + @classmethod + async def _call_llm( + cls, + model_client: ChatCompletionClient, + model_client_stream: bool, + system_messages: List[SystemMessage], + model_context: ChatCompletionContext, + workbench: Sequence[Workbench], + handoff_tools: List[BaseTool[Any, Any]], + agent_name: str, + cancellation_token: CancellationToken, + output_content_type: type[BaseModel] | None, + message_id: str, + ) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]: + """Call the language model with given context and configuration. + + Args: + model_client: Client for model inference + model_client_stream: Whether to stream responses + system_messages: System messages to include + model_context: Context containing message history + workbench: Available workbenches + handoff_tools: Tools for handling handoffs + agent_name: Name of the agent + cancellation_token: Token for cancelling operation + output_content_type: Optional type for structured output + + Returns: + Generator yielding model results or streaming chunks + """ + all_messages = await model_context.get_messages() + llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages) + + tools = [tool for wb in workbench for tool in await wb.list_tools()] + handoff_tools + + if model_client_stream: + model_result: Optional[CreateResult] = None + + async for chunk in model_client.create_stream( + llm_messages, + tools=tools, + json_output=output_content_type, + cancellation_token=cancellation_token, + ): + if isinstance(chunk, CreateResult): + model_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name, full_message_id=message_id) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + if model_result is None: + raise RuntimeError("No final model result in streaming mode.") + yield model_result + else: + model_result = await model_client.create( + llm_messages, + tools=tools, + cancellation_token=cancellation_token, + json_output=output_content_type, + ) + yield model_result + + @classmethod + async def _process_model_result( + cls, + model_result: CreateResult, + inner_messages: List[BaseAgentEvent | BaseChatMessage], + cancellation_token: CancellationToken, + agent_name: str, + system_messages: List[SystemMessage], + model_context: ChatCompletionContext, + workbench: Sequence[Workbench], + handoff_tools: List[BaseTool[Any, Any]], + handoffs: Dict[str, HandoffBase], + model_client: ChatCompletionClient, + model_client_stream: bool, + reflect_on_tool_use: bool, + tool_call_summary_format: str, + tool_call_summary_formatter: Callable[[FunctionCall, FunctionExecutionResult], str] | None, + max_tool_iterations: int, + output_content_type: type[BaseModel] | None, + message_id: str, + format_string: str | None = None, + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """ + Handle final or partial responses from model_result, including tool calls, handoffs, + and reflection if needed. Supports tool call loops when enabled. + """ + + # Tool call loop implementation with streaming support + current_model_result = model_result + # This variable is needed for the final summary/reflection step + executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]] = [] + + for loop_iteration in range(max_tool_iterations): + # If direct text response (string), we're done + if isinstance(current_model_result.content, str): + # Use the passed message ID for the final message + if output_content_type: + content = output_content_type.model_validate_json(current_model_result.content) + yield Response( + chat_message=StructuredMessage[output_content_type]( # type: ignore[valid-type] + content=content, + source=agent_name, + models_usage=current_model_result.usage, + format_string=format_string, + id=message_id, + ), + inner_messages=inner_messages, + ) + else: + yield Response( + chat_message=TextMessage( + content=current_model_result.content, + source=agent_name, + models_usage=current_model_result.usage, + id=message_id, + ), + inner_messages=inner_messages, + ) + return + + # Otherwise, we have function calls + assert isinstance(current_model_result.content, list) and all( + isinstance(item, FunctionCall) for item in current_model_result.content + ) + + # STEP 4A: Yield ToolCallRequestEvent + tool_call_msg = ToolCallRequestEvent( + content=current_model_result.content, + source=agent_name, + models_usage=current_model_result.usage, + ) + event_logger.debug(tool_call_msg) + inner_messages.append(tool_call_msg) + yield tool_call_msg + + # STEP 4B: Execute tool calls with streaming support + # Use a queue to handle streaming results from tool calls. + stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]() + + async def _execute_tool_calls( + function_calls: List[FunctionCall], + stream_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None], + ) -> List[Tuple[FunctionCall, FunctionExecutionResult]]: + results = await asyncio.gather( + *[ + cls._execute_tool_call( + tool_call=call, + workbench=workbench, + handoff_tools=handoff_tools, + agent_name=agent_name, + cancellation_token=cancellation_token, + stream=stream_queue, + ) + for call in function_calls + ] + ) + # Signal the end of streaming by putting None in the queue. + stream_queue.put_nowait(None) + return results + + task = asyncio.create_task(_execute_tool_calls(current_model_result.content, stream)) + + while True: + event = await stream.get() + if event is None: + # End of streaming, break the loop. + break + if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): + yield event + inner_messages.append(event) + else: + raise RuntimeError(f"Unexpected event type: {type(event)}") + + # Wait for all tool calls to complete. + executed_calls_and_results = await task + exec_results = [result for _, result in executed_calls_and_results] + + # Yield ToolCallExecutionEvent + tool_call_result_msg = ToolCallExecutionEvent( + content=exec_results, + source=agent_name, + ) + event_logger.debug(tool_call_result_msg) + await model_context.add_message(FunctionExecutionResultMessage(content=exec_results)) + inner_messages.append(tool_call_result_msg) + yield tool_call_result_msg + + # STEP 4C: Check for handoff + handoff_output = cls._check_and_handle_handoff( + model_result=current_model_result, + executed_calls_and_results=executed_calls_and_results, + inner_messages=inner_messages, + handoffs=handoffs, + agent_name=agent_name, + ) + if handoff_output: + yield handoff_output + return + + # STEP 4D: Check if we should continue the loop. + # If we are on the last iteration, break to the summary/reflection step. + if loop_iteration == max_tool_iterations - 1: + break + + # Continue the loop: make another model call using _call_llm + next_model_result: Optional[CreateResult] = None + async for llm_output in cls._call_llm( + model_client=model_client, + model_client_stream=model_client_stream, + system_messages=system_messages, + model_context=model_context, + workbench=workbench, + handoff_tools=handoff_tools, + agent_name=agent_name, + cancellation_token=cancellation_token, + output_content_type=output_content_type, + message_id=message_id, # Use same message ID for consistency + ): + if isinstance(llm_output, CreateResult): + next_model_result = llm_output + else: + # Streaming chunk event + yield llm_output + + assert next_model_result is not None, "No model result was produced in tool call loop." + current_model_result = next_model_result + + # Yield thought event if present + if current_model_result.thought: + thought_event = ThoughtEvent(content=current_model_result.thought, source=agent_name) + yield thought_event + inner_messages.append(thought_event) + + # Add the assistant message to the model context (including thought if present) + await model_context.add_message( + AssistantMessage( + content=current_model_result.content, + source=agent_name, + thought=getattr(current_model_result, "thought", None), + ) + ) + + # After the loop, reflect or summarize tool results + if reflect_on_tool_use: + async for reflection_response in cls._reflect_on_tool_use_flow( + system_messages=system_messages, + model_client=model_client, + model_client_stream=model_client_stream, + model_context=model_context, + workbench=workbench, + handoff_tools=handoff_tools, + agent_name=agent_name, + inner_messages=inner_messages, + output_content_type=output_content_type, + cancellation_token=cancellation_token, + ): + yield reflection_response + else: + yield cls._summarize_tool_use( + executed_calls_and_results=executed_calls_and_results, + inner_messages=inner_messages, + handoffs=handoffs, + tool_call_summary_format=tool_call_summary_format, + tool_call_summary_formatter=tool_call_summary_formatter, + agent_name=agent_name, + ) + return + + @staticmethod + def _check_and_handle_handoff( + model_result: CreateResult, + executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]], + inner_messages: List[BaseAgentEvent | BaseChatMessage], + handoffs: Dict[str, HandoffBase], + agent_name: str, + ) -> Optional[Response]: + """Check for and handle any handoff requests in the model result. + + Args: + model_result: Result from model inference + executed_calls_and_results: List of executed tool calls and their results + inner_messages: List of messages generated during processing + handoffs: Dictionary of available handoff configurations + agent_name: Name of the agent + + Returns: + Optional response containing handoff message if handoff detected + """ + handoff_reqs = [ + call for call in model_result.content if isinstance(call, FunctionCall) and call.name in handoffs + ] + if len(handoff_reqs) > 0: + # We have at least one handoff function call + selected_handoff = handoffs[handoff_reqs[0].name] + + if len(handoff_reqs) > 1: + warnings.warn( + ( + f"Multiple handoffs detected. Only the first is executed: " + f"{[handoffs[c.name].name for c in handoff_reqs]}. " + "Disable parallel tool calls in the model client to avoid this warning." + ), + stacklevel=2, + ) + + # Collect normal tool calls (not handoff) into the handoff context + tool_calls: List[FunctionCall] = [] + tool_call_results: List[FunctionExecutionResult] = [] + # Collect the results returned by handoff_tool. By default, the message attribute will returned. + selected_handoff_message = selected_handoff.message + for exec_call, exec_result in executed_calls_and_results: + if exec_call.name not in handoffs: + tool_calls.append(exec_call) + tool_call_results.append(exec_result) + elif exec_call.name == selected_handoff.name: + selected_handoff_message = exec_result.content + + handoff_context: List[LLMMessage] = [] + if len(tool_calls) > 0: + # Include the thought in the AssistantMessage if model_result has it + handoff_context.append( + AssistantMessage( + content=tool_calls, + source=agent_name, + thought=getattr(model_result, "thought", None), + ) + ) + handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results)) + elif model_result.thought: + # If no tool calls, but a thought exists, include it in the context + handoff_context.append( + AssistantMessage( + content=model_result.thought, + source=agent_name, + ) + ) + + # Return response for the first handoff + return Response( + chat_message=HandoffMessage( + content=selected_handoff_message, + target=selected_handoff.target, + source=agent_name, + context=handoff_context, + ), + inner_messages=inner_messages, + ) + return None + + @classmethod + async def _reflect_on_tool_use_flow( + cls, + system_messages: List[SystemMessage], + model_client: ChatCompletionClient, + model_client_stream: bool, + model_context: ChatCompletionContext, + workbench: Sequence[Workbench], + handoff_tools: List[BaseTool[Any, Any]], + agent_name: str, + inner_messages: List[BaseAgentEvent | BaseChatMessage], + output_content_type: type[BaseModel] | None, + cancellation_token: CancellationToken, + ) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]: + """ + If reflect_on_tool_use=True, we do another inference based on tool results + and yield the final text response (or streaming chunks). + """ + all_messages = system_messages + await model_context.get_messages() + llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages) + + reflection_result: Optional[CreateResult] = None + + # Generate a message ID for correlation between chunks and final message in reflection flow + reflection_message_id = str(uuid.uuid4()) + + if model_client_stream: + async for chunk in model_client.create_stream( + llm_messages, + json_output=output_content_type, + cancellation_token=cancellation_token, + tool_choice="none", # Do not use tools in reflection flow. + ): + if isinstance(chunk, CreateResult): + reflection_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent( + content=chunk, source=agent_name, full_message_id=reflection_message_id + ) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + else: + reflection_result = await model_client.create( + llm_messages, + json_output=output_content_type, + cancellation_token=cancellation_token, + tool_choice="none", # Do not use tools in reflection flow. + ) + + if not reflection_result or not isinstance(reflection_result.content, str): + raise RuntimeError("Reflect on tool use produced no valid text response.") + + # --- NEW: If the reflection produced a thought, yield it --- + if reflection_result.thought: + thought_event = ThoughtEvent(content=reflection_result.thought, source=agent_name) + yield thought_event + inner_messages.append(thought_event) + + # Add to context (including thought if present) + await model_context.add_message( + AssistantMessage( + content=reflection_result.content, + source=agent_name, + thought=getattr(reflection_result, "thought", None), + ) + ) + + if output_content_type: + content = output_content_type.model_validate_json(reflection_result.content) + yield Response( + chat_message=StructuredMessage[output_content_type]( # type: ignore[valid-type] + content=content, + source=agent_name, + models_usage=reflection_result.usage, + id=reflection_message_id, + ), + inner_messages=inner_messages, + ) + else: + yield Response( + chat_message=TextMessage( + content=reflection_result.content, + source=agent_name, + models_usage=reflection_result.usage, + id=reflection_message_id, + ), + inner_messages=inner_messages, + ) + + @staticmethod + def _summarize_tool_use( + executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]], + inner_messages: List[BaseAgentEvent | BaseChatMessage], + handoffs: Dict[str, HandoffBase], + tool_call_summary_format: str, + tool_call_summary_formatter: Callable[[FunctionCall, FunctionExecutionResult], str] | None, + agent_name: str, + ) -> Response: + """ + If reflect_on_tool_use=False, create a summary message of all tool calls. + """ + # Filter out calls which were actually handoffs + normal_tool_calls = [(call, result) for call, result in executed_calls_and_results if call.name not in handoffs] + + def default_tool_call_summary_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str: + return tool_call_summary_format.format( + tool_name=call.name, + arguments=call.arguments, + result=result.content, + is_error=result.is_error, + ) + + summary_formatter = tool_call_summary_formatter or default_tool_call_summary_formatter + + tool_call_summaries = [summary_formatter(call, result) for call, result in normal_tool_calls] + + tool_call_summary = "\n".join(tool_call_summaries) + return Response( + chat_message=ToolCallSummaryMessage( + content=tool_call_summary, + source=agent_name, + tool_calls=[call for call, _ in normal_tool_calls], + results=[result for _, result in normal_tool_calls], + ), + inner_messages=inner_messages, + ) + + @staticmethod + async def _execute_tool_call( + tool_call: FunctionCall, + workbench: Sequence[Workbench], + handoff_tools: List[BaseTool[Any, Any]], + agent_name: str, + cancellation_token: CancellationToken, + stream: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None], + ) -> Tuple[FunctionCall, FunctionExecutionResult]: + """Execute a single tool call and return the result.""" + # Load the arguments from the tool call. + try: + arguments = json.loads(tool_call.arguments) + except json.JSONDecodeError as e: + return ( + tool_call, + FunctionExecutionResult( + content=f"Error: {e}", + call_id=tool_call.id, + is_error=True, + name=tool_call.name, + ), + ) + + # Check if the tool call is a handoff. + # TODO: consider creating a combined workbench to handle both handoff and normal tools. + for handoff_tool in handoff_tools: + if tool_call.name == handoff_tool.name: + # Run handoff tool call. + result = await handoff_tool.run_json(arguments, cancellation_token, call_id=tool_call.id) + result_as_str = handoff_tool.return_value_as_string(result) + return ( + tool_call, + FunctionExecutionResult( + content=result_as_str, + call_id=tool_call.id, + is_error=False, + name=tool_call.name, + ), + ) + + # Handle normal tool call using workbench. + for wb in workbench: + tools = await wb.list_tools() + if any(t["name"] == tool_call.name for t in tools): + if isinstance(wb, StaticStreamWorkbench): + tool_result: ToolResult | None = None + async for event in wb.call_tool_stream( + name=tool_call.name, + arguments=arguments, + cancellation_token=cancellation_token, + call_id=tool_call.id, + ): + if isinstance(event, ToolResult): + tool_result = event + elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): + await stream.put(event) + else: + warnings.warn( + f"Unexpected event type: {type(event)} in tool call streaming.", + UserWarning, + stacklevel=2, + ) + assert isinstance(tool_result, ToolResult), "Tool result should not be None in streaming mode." + else: + tool_result = await wb.call_tool( + name=tool_call.name, + arguments=arguments, + cancellation_token=cancellation_token, + call_id=tool_call.id, + ) + return ( + tool_call, + FunctionExecutionResult( + content=tool_result.to_text(), + call_id=tool_call.id, + is_error=tool_result.is_error, + name=tool_call.name, + ), + ) + + return ( + tool_call, + FunctionExecutionResult( + content=f"Error: tool '{tool_call.name}' not found in any workbench", + call_id=tool_call.id, + is_error=True, + name=tool_call.name, + ), + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """Reset the assistant agent to its initialization state.""" + await self._model_context.clear() + + async def save_state(self) -> Mapping[str, Any]: + """Save the current state of the assistant agent.""" + model_context_state = await self._model_context.save_state() + return AssistantAgentState(llm_context=model_context_state).model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load the state of the assistant agent""" + assistant_agent_state = AssistantAgentState.model_validate(state) + # Load the model context state. + await self._model_context.load_state(assistant_agent_state.llm_context) + + @staticmethod + def _get_compatible_context(model_client: ChatCompletionClient, messages: List[LLMMessage]) -> Sequence[LLMMessage]: + """Ensure that the messages are compatible with the underlying client, by removing images if needed.""" + if model_client.model_info["vision"]: + return messages + else: + return remove_images(messages) + + def _to_config(self) -> AssistantAgentConfig: + """Convert the assistant agent to a declarative config.""" + + return AssistantAgentConfig( + name=self.name, + model_client=self._model_client.dump_component(), + tools=None, # versionchanged:: v0.5.5 Now tools are not serialized, Cause they are part of the workbench. + workbench=[wb.dump_component() for wb in self._workbench] if self._workbench else None, + handoffs=list(self._handoffs.values()) if self._handoffs else None, + model_context=self._model_context.dump_component(), + memory=[memory.dump_component() for memory in self._memory] if self._memory else None, + description=self.description, + system_message=self._system_messages[0].content + if self._system_messages and isinstance(self._system_messages[0].content, str) + else None, + model_client_stream=self._model_client_stream, + reflect_on_tool_use=self._reflect_on_tool_use, + max_tool_iterations=self._max_tool_iterations, + tool_call_summary_format=self._tool_call_summary_format, + structured_message_factory=self._structured_message_factory.dump_component() + if self._structured_message_factory + else None, + metadata=self._metadata, + ) + + @classmethod + def _from_config(cls, config: AssistantAgentConfig) -> Self: + """Create an assistant agent from a declarative config.""" + if config.structured_message_factory: + structured_message_factory = StructuredMessageFactory.load_component(config.structured_message_factory) + format_string = structured_message_factory.format_string + output_content_type = structured_message_factory.ContentModel + + else: + format_string = None + output_content_type = None + + return cls( + name=config.name, + model_client=ChatCompletionClient.load_component(config.model_client), + workbench=[Workbench.load_component(wb) for wb in config.workbench] if config.workbench else None, + handoffs=config.handoffs, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, + tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None, + memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None, + description=config.description, + system_message=config.system_message, + model_client_stream=config.model_client_stream, + reflect_on_tool_use=config.reflect_on_tool_use, + max_tool_iterations=config.max_tool_iterations, + tool_call_summary_format=config.tool_call_summary_format, + output_content_type=output_content_type, + output_content_type_format=format_string, + metadata=config.metadata, + ) diff --git a/agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py b/agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py new file mode 100644 index 0000000..05da4cd --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py @@ -0,0 +1,245 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, List, Mapping, Sequence + +from agentdhal_core import CancellationToken, ComponentBase, trace_create_agent_span, trace_invoke_agent_span +from pydantic import BaseModel + +from ..base import ChatAgent, Response, TaskResult +from ..messages import ( + BaseAgentEvent, + BaseChatMessage, + ModelClientStreamingChunkEvent, + TextMessage, +) +from ..state import BaseState + + +class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]): + """Base class for a chat agent. + + This abstract class provides a base implementation for a :class:`ChatAgent`. + To create a new chat agent, subclass this class and implement the + :meth:`on_messages`, :meth:`on_reset`, and :attr:`produced_message_types`. + If streaming is required, also implement the :meth:`on_messages_stream` method. + + An agent is considered stateful and maintains its state between calls to + the :meth:`on_messages` or :meth:`on_messages_stream` methods. + The agent should store its state in the + agent instance. The agent should also implement the :meth:`on_reset` method + to reset the agent to its initialization state. + + .. note:: + + The caller should only pass the new messages to the agent on each call + to the :meth:`on_messages` or :meth:`on_messages_stream` method. + Do not pass the entire conversation history to the agent on each call. + This design principle must be followed when creating a new agent. + """ + + component_type = "agent" + + def __init__(self, name: str, description: str) -> None: + """Initialize the agent with a name and description.""" + with trace_create_agent_span( + agent_name=name, + agent_description=description, + ): + self._name = name + if self._name.isidentifier() is False: + raise ValueError("The agent name must be a valid Python identifier.") + self._description = description + + @property + def name(self) -> str: + """The name of the agent. This is used by team to uniquely identify + the agent. It should be unique within the team.""" + return self._name + + @property + def description(self) -> str: + """The description of the agent. This is used by team to + make decisions about which agents to use. The description should + describe the agent's capabilities and how to interact with it.""" + return self._description + + @property + @abstractmethod + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """The types of messages that the agent produces in the + :attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types.""" + ... + + @abstractmethod + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + """Handles incoming messages and returns a response. + + .. note:: + + Agents are stateful and the messages passed to this method should + be the new messages since the last call to this method. The agent + should maintain its state between calls to this method. For example, + if the agent needs to remember the previous messages to respond to + the current message, it should store the previous messages in the + agent state. + + """ + ... + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """Handles incoming messages and returns a stream of messages and + and the final item is the response. The base implementation in + :class:`BaseChatAgent` simply calls :meth:`on_messages` and yields + the messages in the response. + + .. note:: + + Agents are stateful and the messages passed to this method should + be the new messages since the last call to this method. The agent + should maintain its state between calls to this method. For example, + if the agent needs to remember the previous messages to respond to + the current message, it should store the previous messages in the + agent state. + + """ + response = await self.on_messages(messages, cancellation_token) + for inner_message in response.inner_messages or []: + yield inner_message + yield response + + async def run( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> TaskResult: + """Run the agent with the given task and return the result.""" + with trace_invoke_agent_span( + agent_name=self.name, + agent_description=self.description, + ): + if cancellation_token is None: + cancellation_token = CancellationToken() + input_messages: List[BaseChatMessage] = [] + output_messages: List[BaseAgentEvent | BaseChatMessage] = [] + if task is None: + pass + elif isinstance(task, str): + text_msg = TextMessage(content=task, source="user") + input_messages.append(text_msg) + if output_task_messages: + output_messages.append(text_msg) + elif isinstance(task, BaseChatMessage): + input_messages.append(task) + if output_task_messages: + output_messages.append(task) + else: + if not task: + raise ValueError("Task list cannot be empty.") + # Task is a sequence of messages. + for msg in task: + if isinstance(msg, BaseChatMessage): + input_messages.append(msg) + if output_task_messages: + output_messages.append(msg) + else: + raise ValueError(f"Invalid message type in sequence: {type(msg)}") + response = await self.on_messages(input_messages, cancellation_token) + if response.inner_messages is not None: + output_messages += response.inner_messages + output_messages.append(response.chat_message) + return TaskResult(messages=output_messages) + + async def run_stream( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]: + """Run the agent with the given task and return a stream of messages + and the final task result as the last item in the stream. + + Args: + task: The task to run. Can be a string, a single message, or a sequence of messages. + cancellation_token: The cancellation token to kill the task immediately. + output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility. + """ + with trace_invoke_agent_span( + agent_name=self.name, + agent_description=self.description, + ): + if cancellation_token is None: + cancellation_token = CancellationToken() + input_messages: List[BaseChatMessage] = [] + output_messages: List[BaseAgentEvent | BaseChatMessage] = [] + if task is None: + pass + elif isinstance(task, str): + text_msg = TextMessage(content=task, source="user") + input_messages.append(text_msg) + if output_task_messages: + output_messages.append(text_msg) + yield text_msg + elif isinstance(task, BaseChatMessage): + input_messages.append(task) + if output_task_messages: + output_messages.append(task) + yield task + else: + if not task: + raise ValueError("Task list cannot be empty.") + for msg in task: + if isinstance(msg, BaseChatMessage): + input_messages.append(msg) + if output_task_messages: + output_messages.append(msg) + yield msg + else: + raise ValueError(f"Invalid message type in sequence: {type(msg)}") + async for message in self.on_messages_stream(input_messages, cancellation_token): + if isinstance(message, Response): + yield message.chat_message + output_messages.append(message.chat_message) + yield TaskResult(messages=output_messages) + else: + yield message + if isinstance(message, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. + continue + output_messages.append(message) + + @abstractmethod + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """Resets the agent to its initialization state.""" + ... + + async def on_pause(self, cancellation_token: CancellationToken) -> None: + """Called when the agent is paused while running in its :meth:`on_messages` or + :meth:`on_messages_stream` method. This is a no-op by default in the + :class:`BaseChatAgent` class. Subclasses can override this method to + implement custom pause behavior.""" + pass + + async def on_resume(self, cancellation_token: CancellationToken) -> None: + """Called when the agent is resumed from a pause while running in + its :meth:`on_messages` or :meth:`on_messages_stream` method. + This is a no-op by default in the :class:`BaseChatAgent` class. + Subclasses can override this method to implement custom resume behavior.""" + pass + + async def save_state(self) -> Mapping[str, Any]: + """Export state. Default implementation for stateless agents.""" + return BaseState().model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Restore agent from saved state. Default implementation for stateless agents.""" + BaseState.model_validate(state) + + async def close(self) -> None: + """Release any resources held by the agent. This is a no-op by default in the + :class:`BaseChatAgent` class. Subclasses can override this method to + implement custom close behavior.""" + pass diff --git a/agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py b/agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py new file mode 100644 index 0000000..a6c3820 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py @@ -0,0 +1,881 @@ +import logging +import re +from inspect import iscoroutinefunction +from typing import ( + AsyncGenerator, + Awaitable, + Callable, + List, + Optional, + Sequence, + Union, + cast, +) + +from agentdhal_core import CancellationToken, Component, ComponentModel +from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult +from agentdhal_core.model_context import ( + ChatCompletionContext, + UnboundedChatCompletionContext, +) +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + SystemMessage, + UserMessage, +) +from pydantic import BaseModel +from typing_extensions import Self + +from .. import EVENT_LOGGER_NAME +from ..base import Response +from ..messages import ( + BaseAgentEvent, + BaseChatMessage, + CodeExecutionEvent, + CodeGenerationEvent, + HandoffMessage, + ModelClientStreamingChunkEvent, + TextMessage, + ThoughtEvent, +) +from ..utils import remove_images +from ._base_chat_agent import BaseChatAgent + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class CodeExecutorAgentConfig(BaseModel): + """Configuration for CodeExecutorAgent""" + + name: str + code_executor: ComponentModel + model_client: ComponentModel | None = None + description: str | None = None + sources: List[str] | None = None + system_message: str | None = None + model_client_stream: bool = False + model_context: ComponentModel | None = None + supported_languages: List[str] | None = None + + +class RetryDecision(BaseModel): + reason: str + retry: bool + + +class ApprovalRequest(BaseModel): + """Request for approval of code execution.""" + + code: str + context: List[LLMMessage] + + +class ApprovalResponse(BaseModel): + """Response to approval request.""" + + approved: bool + reason: str + + +# Type aliases for approval functions +SyncApprovalFunc = Callable[[ApprovalRequest], ApprovalResponse] +AsyncApprovalFunc = Callable[[ApprovalRequest], Awaitable[ApprovalResponse]] +ApprovalFuncType = Union[SyncApprovalFunc, AsyncApprovalFunc] + + +class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]): + """(Experimental) An agent that generates and executes code snippets based on user instructions. + + .. note:: + + This agent is experimental and may change in future releases. + + It is typically used within a team with another agent that generates code snippets + to be executed or alone with `model_client` provided so that it can generate code + based on user query, execute it and reflect on the code result. + + When used with `model_client`, it will generate code snippets using the model + and execute them using the provided `code_executor`. The model will also reflect on the + code execution results. The agent will yield the final reflection result from the model + as the final response. + + When used without `model_client`, it will only execute code blocks found in + :class:`~agentdhal_agentchat.messages.TextMessage` messages and returns the output + of the code execution. + + .. note:: + + Using :class:`~agentdhal_agentchat.agents.AssistantAgent` with + :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool` + is an alternative to this agent. However, the model for that agent will + have to generate properly escaped code string as a parameter to the tool. + + Args: + name (str): The name of the agent. + code_executor (CodeExecutor): The code executor responsible for executing code received in messages + (:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` recommended. See example below) + model_client (ChatCompletionClient, optional): The model client to use for inference and generating code. + If not provided, the agent will only execute code blocks found in input messages. + Currently, the model must support structured output mode, which is required for + the automatic retry mechanism to work. + model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode. + :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will + also yield :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent` + messages as the model client produces chunks of response. Defaults to `False`. + description (str, optional): The description of the agent. If not provided, + :class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION` will be used. + system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable. + Defaults to :class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_SYSTEM_MESSAGE`. This is only used if `model_client` is provided. + sources (Sequence[str], optional): Check only messages from the specified agents for the code to execute. + This is useful when the agent is part of a group chat and you want to limit the code execution to messages from specific agents. + If not provided, all messages will be checked for code blocks. + This is only used if `model_client` is not provided. + max_retries_on_error (int, optional): The maximum number of retries on error. If the code execution fails, the agent will retry up to this number of times. + If the code execution fails after this number of retries, the agent will yield a reflection result. + supported_languages (List[str], optional): List of programming languages that will be parsed and executed from agent response; + others will be ignored. Defaults to DEFAULT_SUPPORTED_LANGUAGES. + approval_func (Optional[Union[Callable[[ApprovalRequest], ApprovalResponse], Callable[[ApprovalRequest], Awaitable[ApprovalResponse]]]], optional): A function that is called before each code execution to get approval. + The function takes an ApprovalRequest containing the code to be executed and the current context, and returns an ApprovalResponse. + The function can be either synchronous or asynchronous. If None (default), all code executions are automatically approved. + If set, the agent cannot be serialized using :meth:`~agentdhal_agentchat.agents.CodeExecutorAgent.dump_component`. + + + .. note:: + + It is recommended that the `CodeExecutorAgent` agent uses a Docker container to execute code. This ensures that model-generated code is executed in an isolated environment. To use Docker, your environment must have Docker installed and running. + Follow the installation instructions for `Docker `_. + + .. note:: + + The code executor only processes code that is properly formatted in markdown code blocks using triple backticks. + For example: + + .. code-block:: text + + ```python + print("Hello World") + ``` + + # or + + ```sh + echo "Hello World" + ``` + + In this example, we show how to set up a `CodeExecutorAgent` agent that uses the + :py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` + to execute code snippets in a Docker container. The `work_dir` parameter indicates + where all executed files are first saved locally before being executed in the Docker container. + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse + from agentdhal_agentchat.messages import TextMessage + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_core import CancellationToken + + + def simple_approval_func(request: ApprovalRequest) -> ApprovalResponse: + \"\"\"Simple approval function that requests user input for code execution approval.\"\"\" + print("Code execution approval requested:") + print("=" * 50) + print(request.code) + print("=" * 50) + + while True: + user_input = input("Do you want to execute this code? (y/n): ").strip().lower() + if user_input in ['y', 'yes']: + return ApprovalResponse(approved=True, reason='Approved by user') + elif user_input in ['n', 'no']: + return ApprovalResponse(approved=False, reason='Denied by user') + else: + print("Please enter 'y' for yes or 'n' for no.") + + + async def run_code_executor_agent() -> None: + # Create a code executor agent that uses a Docker container to execute code. + code_executor = DockerCommandLineCodeExecutor(work_dir="coding") + await code_executor.start() + code_executor_agent = CodeExecutorAgent( + "code_executor", + code_executor=code_executor, + approval_func=simple_approval_func + ) + + # Run the agent with a given code snippet. + task = TextMessage( + content='''Here is some code + ```python + print('Hello world') + ``` + ''', + source="user", + ) + response = await code_executor_agent.on_messages([task], CancellationToken()) + print(response.chat_message) + + # Stop the code executor. + await code_executor.stop() + + + asyncio.run(run_code_executor_agent()) + + In this example, we show how to set up a `CodeExecutorAgent` agent that uses the + :py:class:`~docker.types.DeviceRequest` to expose a GPU to the container for cuda-accelerated code execution. + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import CodeExecutorAgent + from agentdhal_agentchat.messages import TextMessage + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_core import CancellationToken + from docker.types import DeviceRequest + + + async def run_code_executor_agent() -> None: + # Create a code executor agent that uses a Docker container to execute code. + code_executor = DockerCommandLineCodeExecutor( + work_dir="coding", device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])] + ) + await code_executor.start() + code_executor_agent = CodeExecutorAgent("code_executor", code_executor=code_executor) + + # Display the GPU information + task = TextMessage( + content='''Here is some code + ```sh + nvidia-smi + ``` + ''', + source="user", + ) + response = await code_executor_agent.on_messages([task], CancellationToken()) + print(response.chat_message) + + # Stop the code executor. + await code_executor.stop() + + + asyncio.run(run_code_executor_agent()) + + In the following example, we show how to setup `CodeExecutorAgent` without `model_client` parameter for executing code blocks generated by other agents in a group chat using :py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` + + .. code-block:: python + + import asyncio + + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + from agentdhal_agentchat.agents import AssistantAgent, CodeExecutorAgent, ApprovalRequest, ApprovalResponse + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.ui import Console + + termination_condition = MaxMessageTermination(3) + + + def group_chat_approval_func(request: ApprovalRequest) -> ApprovalResponse: + \"\"\"Approval function for group chat that allows basic Python operations.\"\"\" + # Allow common safe operations + safe_operations = ["print(", "import ", "def ", "class ", "if ", "for ", "while "] + if any(op in request.code for op in safe_operations): + return ApprovalResponse(approved=True, reason='Safe Python operation') + + # Deny file system operations in group chat + dangerous_operations = ["open(", "file(", "os.", "subprocess", "eval(", "exec("] + if any(op in request.code for op in dangerous_operations): + return ApprovalResponse(approved=False, reason='File system or dangerous operation not allowed') + + return ApprovalResponse(approved=True, reason='Operation approved') + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + # define the Docker CLI Code Executor + code_executor = DockerCommandLineCodeExecutor(work_dir="coding") + + # start the execution container + await code_executor.start() + + code_executor_agent = CodeExecutorAgent( + "code_executor_agent", + code_executor=code_executor, + approval_func=group_chat_approval_func + ) + coder_agent = AssistantAgent("coder_agent", model_client=model_client) + + groupchat = RoundRobinGroupChat( + participants=[coder_agent, code_executor_agent], termination_condition=termination_condition + ) + + task = "Write python code to print Hello World!" + await Console(groupchat.run_stream(task=task)) + + # stop the execution container + await code_executor.stop() + + + asyncio.run(main()) + + .. code-block:: text + + ---------- user ---------- + Write python code to print Hello World! + ---------- coder_agent ---------- + Certainly! Here's a simple Python code to print "Hello World!": + + ```python + print("Hello World!") + ``` + + You can run this code in any Python environment to display the message. + ---------- code_executor_agent ---------- + Hello World! + + In the following example, we show how to setup `CodeExecutorAgent` with `model_client` + that can generate its own code without the help of any other agent and executing it in + :py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`. + It also demonstrates using a model-based approval function that reviews the code for safety before execution. + + .. code-block:: python + + import asyncio + + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_core.models import SystemMessage, UserMessage + + from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse + from agentdhal_agentchat.conditions import TextMessageTermination + from agentdhal_agentchat.ui import Console + + termination_condition = TextMessageTermination("code_executor_agent") + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + async def model_client_approval_func(request: ApprovalRequest) -> ApprovalResponse: + instruction = "Approve or reject the code in the last message based on whether it is dangerous or not. Use the following JSON format for your response: {approved: true/false, reason: 'your reason here'}" + response = await model_client.create( + messages=[SystemMessage(content=instruction)] + + request.context + + [UserMessage(content=request.code, source="user")], + json_output=ApprovalResponse, + ) + assert isinstance(response.content, str) + return ApprovalResponse.model_validate_json(response.content) + + # define the Docker CLI Code Executor + code_executor = DockerCommandLineCodeExecutor(work_dir="coding") + + # start the execution container + await code_executor.start() + + code_executor_agent = CodeExecutorAgent( + "code_executor_agent", + code_executor=code_executor, + model_client=model_client, + approval_func=model_client_approval_func, + ) + + task = "Write python code to print Hello World!" + await Console(code_executor_agent.run_stream(task=task)) + + # stop the execution container + await code_executor.stop() + + + asyncio.run(main()) + + + .. code-block:: text + + ---------- user ---------- + Write python code to print Hello World! + ---------- code_executor_agent ---------- + Certainly! Here is a simple Python code to print "Hello World!" to the console: + + ```python + print("Hello World!") + ``` + + Let's execute it to confirm the output. + ---------- code_executor_agent ---------- + Hello World! + + ---------- code_executor_agent ---------- + The code has been executed successfully, and it printed "Hello World!" as expected. If you have any more requests or questions, feel free to ask! + + """ + + DEFAULT_TERMINAL_DESCRIPTION = "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks)." + DEFAULT_AGENT_DESCRIPTION = "A Code Execution Agent that generates and executes Python and shell scripts based on user instructions. It ensures correctness, efficiency, and minimal errors while gracefully handling edge cases." + DEFAULT_SYSTEM_MESSAGE = "You are a Code Execution Agent. Your role is to generate and execute Python code and shell scripts based on user instructions, ensuring correctness, efficiency, and minimal errors. Handle edge cases gracefully. Python code should be provided in ```python code blocks, and sh shell scripts should be provided in ```sh code blocks for execution." + NO_CODE_BLOCKS_FOUND_MESSAGE = "No code blocks found in the thread. Please provide at least one markdown-encoded code block to execute (i.e., quoting code in ```python or ```sh code blocks)." + DEFAULT_SUPPORTED_LANGUAGES = ["python", "sh"] + + component_config_schema = CodeExecutorAgentConfig + component_provider_override = "agentdhal_agentchat.agents.CodeExecutorAgent" + + def __init__( + self, + name: str, + code_executor: CodeExecutor, + *, + model_client: ChatCompletionClient | None = None, + model_context: ChatCompletionContext | None = None, + model_client_stream: bool = False, + max_retries_on_error: int = 0, + description: str | None = None, + system_message: str | None = DEFAULT_SYSTEM_MESSAGE, + sources: Sequence[str] | None = None, + supported_languages: List[str] | None = None, + approval_func: Optional[ApprovalFuncType] = None, + ) -> None: + if description is None: + if model_client is None: + description = CodeExecutorAgent.DEFAULT_TERMINAL_DESCRIPTION + else: + description = CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION + + super().__init__(name=name, description=description) + self._code_executor = code_executor + self._sources = sources + self._model_client_stream = model_client_stream + self._max_retries_on_error = max_retries_on_error + self._approval_func = approval_func + self._approval_func_is_async = approval_func is not None and iscoroutinefunction(approval_func) + + if supported_languages is not None: + self._supported_languages = supported_languages + else: + self._supported_languages = CodeExecutorAgent.DEFAULT_SUPPORTED_LANGUAGES + + self._supported_languages_regex = "|".join(re.escape(lang) for lang in self._supported_languages) + + self._model_client = None + if model_client is not None: + self._model_client = model_client + + if model_context is not None: + self._model_context = model_context + else: + self._model_context = UnboundedChatCompletionContext() + + self._system_messaages: List[SystemMessage] = [] + if system_message is None: + self._system_messages = [] + else: + self._system_messages = [SystemMessage(content=system_message)] + + if self._max_retries_on_error > 0: + if not self._model_client or not self._model_client.model_info: + raise ValueError("model_client.model_info must be provided when max_retries_on_error > 0") + if not self._model_client.model_info["structured_output"]: + raise ValueError("Specified model_client doesn't support structured output mode.") + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """The types of messages that the code executor agent produces.""" + return (TextMessage,) + + @property + def model_context(self) -> ChatCompletionContext: + """ + The model context in use by the agent. + """ + return self._model_context + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """ + Process the incoming messages with the assistant agent and yield events/responses as they happen. + """ + + # Gather all relevant state here + agent_name = self.name + model_context = self._model_context + system_messages = self._system_messages + model_client = self._model_client + model_client_stream = self._model_client_stream + max_retries_on_error = self._max_retries_on_error + + execution_result: CodeResult | None = None + if model_client is None: # default behaviour for backward compatibility + # execute generated code if present + code_blocks: List[CodeBlock] = await self.extract_code_blocks_from_messages(messages) + if not code_blocks: + yield Response( + chat_message=TextMessage( + content=self.NO_CODE_BLOCKS_FOUND_MESSAGE, + source=agent_name, + ) + ) + return + execution_result = await self.execute_code_block(code_blocks, cancellation_token) + yield Response(chat_message=TextMessage(content=execution_result.output, source=self.name)) + return + + inner_messages: List[BaseAgentEvent | BaseChatMessage] = [] + + for nth_try in range(max_retries_on_error + 1): # Do one default generation, execution and inference loop + # Step 1: Add new user/handoff messages to the model context + await self._add_messages_to_context( + model_context=model_context, + messages=messages, + ) + + # Step 2: Run inference with the model context + model_result = None + async for inference_output in self._call_llm( + model_client=model_client, + model_client_stream=model_client_stream, + system_messages=system_messages, + model_context=model_context, + agent_name=agent_name, + cancellation_token=cancellation_token, + ): + if isinstance(inference_output, CreateResult): + model_result = inference_output + else: + # Streaming chunk event + yield inference_output + + assert model_result is not None, "No model result was produced." + + # Step 3: [NEW] If the model produced a hidden "thought," yield it as an event + if model_result.thought: + thought_event = ThoughtEvent(content=model_result.thought, source=agent_name) + yield thought_event + inner_messages.append(thought_event) + + # Step 4: Add the assistant message to the model context (including thought if present) + await model_context.add_message( + AssistantMessage( + content=model_result.content, + source=agent_name, + thought=getattr(model_result, "thought", None), + ) + ) + + # Step 5: Extract the code blocks from inferred text + assert isinstance(model_result.content, str), "Expected inferred model_result.content to be of type str." + code_blocks = self._extract_markdown_code_blocks(str(model_result.content)) + + # Step 6: Exit the loop if no code blocks found + if not code_blocks: + yield Response( + chat_message=TextMessage( + content=str(model_result.content), + source=agent_name, + ) + ) + return + + # Step 7: Yield a CodeGenerationEvent + inferred_text_message: CodeGenerationEvent = CodeGenerationEvent( + retry_attempt=nth_try, + content=model_result.content, + code_blocks=code_blocks, + source=agent_name, + ) + + yield inferred_text_message + + # Step 8: Execute the extracted code blocks + execution_result = await self.execute_code_block(inferred_text_message.code_blocks, cancellation_token) + + # Step 9: Update model context with the code execution result + await model_context.add_message( + UserMessage( + content=execution_result.output, + source=agent_name, + ) + ) + + # Step 10: Yield a CodeExecutionEvent + yield CodeExecutionEvent(retry_attempt=nth_try, result=execution_result, source=self.name) + + # If execution was successful or last retry, then exit + if execution_result.exit_code == 0 or nth_try == max_retries_on_error: + break + + # Step 11: If exit code is non-zero and retries are available then + # make an inference asking if we should retry or not + chat_context = await model_context.get_messages() + + retry_prompt = ( + f"The most recent code execution resulted in an error:\n{execution_result.output}\n\n" + "Should we attempt to resolve it? Please respond with:\n" + "- A boolean value for 'retry' indicating whether it should be retried.\n" + "- A detailed explanation in 'reason' that identifies the issue, justifies your decision to retry or not, and outlines how you would resolve the error if a retry is attempted." + ) + + chat_context = chat_context + [ + UserMessage( + content=retry_prompt, + source=agent_name, + ) + ] + + response = await model_client.create(messages=chat_context, json_output=RetryDecision) + + assert isinstance( + response.content, str + ), "Expected structured response for retry decision to be of type str." + should_retry_generation = RetryDecision.model_validate_json(str(response.content)) + + # Exit if no-retry is needed + if not should_retry_generation.retry: + break + + yield CodeGenerationEvent( + retry_attempt=nth_try, + content=f"Attempt number: {nth_try + 1}\nProposed correction: {should_retry_generation.reason}", + code_blocks=[], + source=agent_name, + ) + + # Always reflect on the execution result + async for reflection_response in CodeExecutorAgent._reflect_on_code_block_results_flow( + system_messages=system_messages, + model_client=model_client, + model_client_stream=model_client_stream, + model_context=model_context, + agent_name=agent_name, + inner_messages=inner_messages, + ): + yield reflection_response # Last reflection_response is of type Response so it will finish the routine + + async def extract_code_blocks_from_messages(self, messages: Sequence[BaseChatMessage]) -> List[CodeBlock]: + # Extract code blocks from the messages. + code_blocks: List[CodeBlock] = [] + for msg in messages: + if self._sources is None or msg.source in self._sources: + if isinstance(msg, TextMessage): + code_blocks.extend(self._extract_markdown_code_blocks(msg.content)) + # TODO: handle other message types if needed + return code_blocks + + async def execute_code_block( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CodeResult: + # Check for approval before executing code blocks + if self._approval_func is not None: + # Combine all code blocks into a single string for approval + combined_code = "\n\n".join([f"```{block.language}\n{block.code}\n```" for block in code_blocks]) + + # Get the current context from model_context + context_messages = await self._model_context.get_messages() + + # Create approval request + approval_request = ApprovalRequest(code=combined_code, context=context_messages) + + # Get approval (handle both sync and async functions) + if self._approval_func_is_async: + # Cast to AsyncApprovalFunc for proper typing + async_func = cast(AsyncApprovalFunc, self._approval_func) + approval_response = await async_func(approval_request) + else: + # Cast to SyncApprovalFunc for proper typing + sync_func = cast(SyncApprovalFunc, self._approval_func) + approval_response = sync_func(approval_request) + + # If not approved, return error result + if not approval_response.approved: + return CodeResult( + exit_code=1, output=f"Code execution was not approved. Reason: {approval_response.reason}" + ) + + # Execute the code blocks. + result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token) + + if result.output.strip() == "": + # No output + result.output = f"The script ran but produced no output to console. The POSIX exit code was: {result.exit_code}. If you were expecting output, consider revising the script to ensure content is printed to stdout." + elif result.exit_code != 0: + # Error + result.output = f"The script ran, then exited with an error (POSIX exit code: {result.exit_code})\nIts output was:\n{result.output}" + + return result + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """Its a no-op as the code executor agent has no mutable state.""" + pass + + def _extract_markdown_code_blocks(self, markdown_text: str) -> List[CodeBlock]: + pattern = re.compile(rf"```(?:\s*({self._supported_languages_regex}))\n([\s\S]*?)```", re.IGNORECASE) + matches = pattern.findall(markdown_text) + code_blocks: List[CodeBlock] = [] + for match in matches: + language = match[0].strip() if match[0] else "" + code_content = match[1] + code_blocks.append(CodeBlock(code=code_content, language=language)) + return code_blocks + + def _to_config(self) -> CodeExecutorAgentConfig: + if self._approval_func is not None: + raise ValueError( + "Cannot serialize CodeExecutorAgent with approval_func set. The approval function is not serializable." + ) + + return CodeExecutorAgentConfig( + name=self.name, + model_client=(self._model_client.dump_component() if self._model_client is not None else None), + code_executor=self._code_executor.dump_component(), + description=self.description, + sources=list(self._sources) if self._sources is not None else None, + system_message=( + self._system_messages[0].content + if self._system_messages and isinstance(self._system_messages[0].content, str) + else None + ), + model_client_stream=self._model_client_stream, + model_context=self._model_context.dump_component(), + supported_languages=self._supported_languages, + ) + + @classmethod + def _from_config(cls, config: CodeExecutorAgentConfig) -> Self: + return cls( + name=config.name, + model_client=( + ChatCompletionClient.load_component(config.model_client) if config.model_client is not None else None + ), + code_executor=CodeExecutor.load_component(config.code_executor), + description=config.description, + sources=config.sources, + system_message=config.system_message, + model_client_stream=config.model_client_stream, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, + supported_languages=config.supported_languages, + approval_func=None, # approval_func cannot be serialized, so it's always None when loading from config + ) + + @staticmethod + def _get_compatible_context(model_client: ChatCompletionClient, messages: List[LLMMessage]) -> Sequence[LLMMessage]: + """Ensure that the messages are compatible with the underlying client, by removing images if needed.""" + if model_client.model_info["vision"]: + return messages + else: + return remove_images(messages) + + @classmethod + async def _call_llm( + cls, + model_client: ChatCompletionClient, + model_client_stream: bool, + system_messages: List[SystemMessage], + model_context: ChatCompletionContext, + agent_name: str, + cancellation_token: CancellationToken, + ) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]: + """ + Perform a model inference and yield either streaming chunk events or the final CreateResult. + """ + all_messages = await model_context.get_messages() + llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages) + + if model_client_stream: + model_result: Optional[CreateResult] = None + async for chunk in model_client.create_stream( + llm_messages, tools=[], cancellation_token=cancellation_token + ): + if isinstance(chunk, CreateResult): + model_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + if model_result is None: + raise RuntimeError("No final model result in streaming mode.") + yield model_result + else: + model_result = await model_client.create(llm_messages, tools=[], cancellation_token=cancellation_token) + yield model_result + + @staticmethod + async def _add_messages_to_context( + model_context: ChatCompletionContext, + messages: Sequence[BaseChatMessage], + ) -> None: + """ + Add incoming messages to the model context. + """ + for msg in messages: + if isinstance(msg, HandoffMessage): + for llm_msg in msg.context: + await model_context.add_message(llm_msg) + await model_context.add_message(msg.to_model_message()) + + @classmethod + async def _reflect_on_code_block_results_flow( + cls, + system_messages: List[SystemMessage], + model_client: ChatCompletionClient, + model_client_stream: bool, + model_context: ChatCompletionContext, + agent_name: str, + inner_messages: List[BaseAgentEvent | BaseChatMessage], + ) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]: + """ + If reflect_on_code_block_results=True, we do another inference based on tool results + and yield the final text response (or streaming chunks). + """ + all_messages = system_messages + await model_context.get_messages() + llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages) + + reflection_result: Optional[CreateResult] = None + + if model_client_stream: + async for chunk in model_client.create_stream(llm_messages): + if isinstance(chunk, CreateResult): + reflection_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + else: + reflection_result = await model_client.create(llm_messages) + + if not reflection_result or not isinstance(reflection_result.content, str): + raise RuntimeError("Reflect on tool use produced no valid text response.") + + # --- NEW: If the reflection produced a thought, yield it --- + if reflection_result.thought: + thought_event = ThoughtEvent(content=reflection_result.thought, source=agent_name) + yield thought_event + inner_messages.append(thought_event) + + # Add to context (including thought if present) + await model_context.add_message( + AssistantMessage( + content=reflection_result.content, + source=agent_name, + thought=getattr(reflection_result, "thought", None), + ) + ) + + yield Response( + chat_message=TextMessage( + content=reflection_result.content, + source=agent_name, + models_usage=reflection_result.usage, + ), + inner_messages=inner_messages, + ) diff --git a/agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py b/agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py new file mode 100644 index 0000000..2e16548 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py @@ -0,0 +1,203 @@ +from typing import AsyncGenerator, List, Literal, Optional, Sequence, Union + +from agentdhal_core import CancellationToken, Component, ComponentModel +from pydantic import BaseModel + +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage + +# ------------------------------ +# Message Filter Config +# ------------------------------ + + +class PerSourceFilter(BaseModel): + source: str + position: Optional[Literal["first", "last"]] = None + count: Optional[int] = None + + +class MessageFilterConfig(BaseModel): + per_source: List[PerSourceFilter] + + +# ------------------------------ +# Component Config +# ------------------------------ + + +class MessageFilterAgentConfig(BaseModel): + name: str + wrapped_agent: ComponentModel + filter: MessageFilterConfig + + +# ------------------------------ +# Message Filter Agent +# ------------------------------ + + +class MessageFilterAgent(BaseChatAgent, Component[MessageFilterAgentConfig]): + """ + A wrapper agent that filters incoming messages before passing them to the inner agent. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + This is useful in scenarios like multi-agent workflows where an agent should only + process a subset of the full message history—for example, only the last message + from each upstream agent, or only the first message from a specific source. + + Filtering is configured using :class:`MessageFilterConfig`, which supports: + - Filtering by message source (e.g., only messages from "user" or another agent) + - Selecting the first N or last N messages from each source + - If position is `None`, all messages from that source are included + + This agent is compatible with both direct message passing and team-based execution + such as :class:`~agentdhal_agentchat.teams.GraphFlow`. + + Example: + >>> agent_a = MessageFilterAgent( + ... name="A", + ... wrapped_agent=some_other_agent, + ... filter=MessageFilterConfig( + ... per_source=[ + ... PerSourceFilter(source="user", position="first", count=1), + ... PerSourceFilter(source="B", position="last", count=2), + ... ] + ... ), + ... ) + + Example use case with Graph: + Suppose you have a looping multi-agent graph: A → B → A → B → C. + + You want: + - A to only see the user message and the last message from B + - B to see the user message, last message from A, and its own prior responses (for reflection) + - C to see the user message and the last message from B + + Wrap the agents like so: + + >>> agent_a = MessageFilterAgent( + ... name="A", + ... wrapped_agent=agent_a_inner, + ... filter=MessageFilterConfig( + ... per_source=[ + ... PerSourceFilter(source="user", position="first", count=1), + ... PerSourceFilter(source="B", position="last", count=1), + ... ] + ... ), + ... ) + + >>> agent_b = MessageFilterAgent( + ... name="B", + ... wrapped_agent=agent_b_inner, + ... filter=MessageFilterConfig( + ... per_source=[ + ... PerSourceFilter(source="user", position="first", count=1), + ... PerSourceFilter(source="A", position="last", count=1), + ... PerSourceFilter(source="B", position="last", count=10), + ... ] + ... ), + ... ) + + >>> agent_c = MessageFilterAgent( + ... name="C", + ... wrapped_agent=agent_c_inner, + ... filter=MessageFilterConfig( + ... per_source=[ + ... PerSourceFilter(source="user", position="first", count=1), + ... PerSourceFilter(source="B", position="last", count=1), + ... ] + ... ), + ... ) + + Then define the graph: + + >>> graph = DiGraph( + ... nodes={ + ... "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]), + ... "B": DiGraphNode( + ... name="B", + ... edges=[ + ... DiGraphEdge(target="C", condition="exit"), + ... DiGraphEdge(target="A", condition="loop"), + ... ], + ... ), + ... "C": DiGraphNode(name="C", edges=[]), + ... }, + ... default_start_node="A", + ... ) + + This will ensure each agent sees only what is needed for its decision or action logic. + """ + + component_config_schema = MessageFilterAgentConfig + component_provider_override = "agentdhal_agentchat.agents.MessageFilterAgent" + + def __init__( + self, + name: str, + wrapped_agent: BaseChatAgent, + filter: MessageFilterConfig, + ): + super().__init__(name=name, description=f"{wrapped_agent.description} (with message filtering)") + self._wrapped_agent = wrapped_agent + self._filter = filter + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return self._wrapped_agent.produced_message_types + + def _apply_filter(self, messages: Sequence[BaseChatMessage]) -> Sequence[BaseChatMessage]: + result: List[BaseChatMessage] = [] + + for source_filter in self._filter.per_source: + msgs = [m for m in messages if m.source == source_filter.source] + + if source_filter.position == "first" and source_filter.count: + msgs = msgs[: source_filter.count] + elif source_filter.position == "last" and source_filter.count: + msgs = msgs[-source_filter.count :] + + result.extend(msgs) + + return result + + async def on_messages( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: CancellationToken, + ) -> Response: + filtered = self._apply_filter(messages) + return await self._wrapped_agent.on_messages(filtered, cancellation_token) + + async def on_messages_stream( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: CancellationToken, + ) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]: + filtered = self._apply_filter(messages) + async for item in self._wrapped_agent.on_messages_stream(filtered, cancellation_token): + yield item + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + await self._wrapped_agent.on_reset(cancellation_token) + + def _to_config(self) -> MessageFilterAgentConfig: + return MessageFilterAgentConfig( + name=self.name, + wrapped_agent=self._wrapped_agent.dump_component(), + filter=self._filter, + ) + + @classmethod + def _from_config(cls, config: MessageFilterAgentConfig) -> "MessageFilterAgent": + wrapped = BaseChatAgent.load_component(config.wrapped_agent) + return cls( + name=config.name, + wrapped_agent=wrapped, + filter=config.filter, + ) diff --git a/agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py b/agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py new file mode 100644 index 0000000..ac1ae21 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py @@ -0,0 +1,302 @@ +from typing import Any, AsyncGenerator, List, Mapping, Sequence + +from agentdhal_core import CancellationToken, Component, ComponentModel +from agentdhal_core.model_context import ( + ChatCompletionContext, + UnboundedChatCompletionContext, +) +from agentdhal_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage +from pydantic import BaseModel +from typing_extensions import Self + +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.state import SocietyOfMindAgentState + +from ..base import TaskResult, Team +from ..messages import ( + BaseAgentEvent, + BaseChatMessage, + HandoffMessage, + ModelClientStreamingChunkEvent, + TextMessage, +) +from ._base_chat_agent import BaseChatAgent + + +class SocietyOfMindAgentConfig(BaseModel): + """The declarative configuration for a SocietyOfMindAgent.""" + + name: str + team: ComponentModel + model_client: ComponentModel + description: str | None = None + instruction: str | None = None + response_prompt: str | None = None + model_context: ComponentModel | None = None + + +class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]): + """An agent that uses an inner team of agents to generate responses. + + Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream` + method is called, it runs the inner team of agents and then uses the + model client to generate a response based on the inner team's messages. + Once the response is generated, the agent resets the inner team by + calling :meth:`Team.reset`. + + Limit context size sent to the model: + + You can limit the number of messages sent to the model by setting + the `model_context` parameter to a :class:`~agentdhal_core.model_context.BufferedChatCompletionContext`. + This will limit the number of recent messages sent to the model and can be useful + when the model has a limit on the number of tokens it can process. + You can also create your own model context by subclassing + :class:`~agentdhal_core.model_context.ChatCompletionContext`. + + + Args: + name (str): The name of the agent. + team (Team): The team of agents to use. + model_client (ChatCompletionClient): The model client to use for preparing responses. + description (str, optional): The description of the agent. + instruction (str, optional): The instruction to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'. + response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'. + model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset. + + + + Example: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.agents import AssistantAgent, SocietyOfMindAgent + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.") + agent2 = AssistantAgent( + "assistant2", + model_client=model_client, + system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.", + ) + inner_termination = TextMentionTermination("APPROVE") + inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) + + society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) + + agent3 = AssistantAgent( + "assistant3", model_client=model_client, system_message="Translate the text to Spanish." + ) + team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2) + + stream = team.run_stream(task="Write a short story with a surprising ending.") + await Console(stream) + + + asyncio.run(main()) + """ + + component_config_schema = SocietyOfMindAgentConfig + component_provider_override = "agentdhal_agentchat.agents.SocietyOfMindAgent" + + DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:" + """str: The default instruction to use when generating a response using the + inner team's messages. The instruction will be prepended to the inner team's + messages when generating a response using the model. It assumes the role of + 'system'.""" + + DEFAULT_RESPONSE_PROMPT = ( + "Output a standalone response to the original request, without mentioning any of the intermediate discussion." + ) + """str: The default response prompt to use when generating a response using + the inner team's messages. It assumes the role of 'system'.""" + + DEFAULT_DESCRIPTION = "An agent that uses an inner team of agents to generate responses." + """str: The default description for a SocietyOfMindAgent.""" + + def __init__( + self, + name: str, + team: Team, + model_client: ChatCompletionClient, + *, + description: str = DEFAULT_DESCRIPTION, + instruction: str = DEFAULT_INSTRUCTION, + response_prompt: str = DEFAULT_RESPONSE_PROMPT, + model_context: ChatCompletionContext | None = None, + ) -> None: + super().__init__(name=name, description=description) + self._team = team + self._model_client = model_client + self._instruction = instruction + self._response_prompt = response_prompt + + if model_context is not None: + self._model_context = model_context + else: + self._model_context = UnboundedChatCompletionContext() + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return (TextMessage,) + + @property + def model_context(self) -> ChatCompletionContext: + """ + The model context in use by the agent. + """ + return self._model_context + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + # Call the stream method and collect the messages. + response: Response | None = None + async for msg in self.on_messages_stream(messages, cancellation_token): + if isinstance(msg, Response): + response = msg + assert response is not None + return response + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + # Prepare the task for the team of agents. + task_messages = list(messages) + + # Run the team of agents. + result: TaskResult | None = None + inner_messages: List[BaseAgentEvent | BaseChatMessage] = [] + model_context = self._model_context + + prev_content = await model_context.get_messages() + if len(prev_content) > 0: + prev_message = HandoffMessage( + content="relevant previous messages", + source=self.name, + target="", + context=prev_content, + ) + task_messages = [prev_message] + task_messages + + if len(task_messages) == 0: + task = None + else: + task = task_messages + + # Use the new output_task_messages parameter to avoid fragile count-based logic + async for inner_msg in self._team.run_stream( + task=task, cancellation_token=cancellation_token, output_task_messages=False + ): + if isinstance(inner_msg, TaskResult): + result = inner_msg + else: + yield inner_msg + if isinstance(inner_msg, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. + continue + inner_messages.append(inner_msg) + assert result is not None + + if len(inner_messages) == 0: + yield Response( + chat_message=TextMessage(source=self.name, content="No response."), + inner_messages=[], + # Response's inner_messages should be empty. Cause that mean is response to outer world. + ) + else: + llm_messages: List[LLMMessage] = [] + + if self._model_client.model_info.get("multiple_system_messages", False): + # The model client supports multiple system messages, so we + llm_messages.append(SystemMessage(content=self._instruction)) + else: + # The model client does not support multiple system messages, so we + llm_messages.append(UserMessage(content=self._instruction, source="user")) + + # Generate a response using the model client. + for message in inner_messages: + if isinstance(message, BaseChatMessage): + llm_messages.append(message.to_model_message()) + + if self._model_client.model_info.get("multiple_system_messages", False): + # The model client supports multiple system messages, so we + llm_messages.append(SystemMessage(content=self._response_prompt)) + else: + # The model client does not support multiple system messages, so we + llm_messages.append(UserMessage(content=self._response_prompt, source="user")) + completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token) + assert isinstance(completion.content, str) + yield Response( + chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage), + inner_messages=[], + # Response's inner_messages should be empty. Cause that mean is response to outer world. + ) + + # Add new user/handoff messages to the model context + await self._add_messages_to_context( + model_context=model_context, + messages=messages, + ) + + # Reset the team. + await self._team.reset() + + @staticmethod + async def _add_messages_to_context( + model_context: ChatCompletionContext, + messages: Sequence[BaseChatMessage], + ) -> None: + """ + Add incoming messages to the model context. + """ + for msg in messages: + if isinstance(msg, HandoffMessage): + for llm_msg in msg.context: + await model_context.add_message(llm_msg) + await model_context.add_message(msg.to_model_message()) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + await self._team.reset() + await self._model_context.clear() + + async def save_state(self) -> Mapping[str, Any]: + team_state = await self._team.save_state() + state = SocietyOfMindAgentState(inner_team_state=team_state) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + society_of_mind_state = SocietyOfMindAgentState.model_validate(state) + await self._team.load_state(society_of_mind_state.inner_team_state) + + def _to_config(self) -> SocietyOfMindAgentConfig: + return SocietyOfMindAgentConfig( + name=self.name, + team=self._team.dump_component(), + model_client=self._model_client.dump_component(), + description=self.description, + instruction=self._instruction, + response_prompt=self._response_prompt, + model_context=self._model_context.dump_component(), + ) + + @classmethod + def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self: + model_client = ChatCompletionClient.load_component(config.model_client) + team = Team.load_component(config.team) + return cls( + name=config.name, + team=team, + model_client=model_client, + description=config.description or cls.DEFAULT_DESCRIPTION, + instruction=config.instruction or cls.DEFAULT_INSTRUCTION, + response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, + ) diff --git a/agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py b/agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py new file mode 100644 index 0000000..a1c5e9d --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py @@ -0,0 +1,249 @@ +import asyncio +import uuid +from contextlib import contextmanager +from contextvars import ContextVar +from inspect import iscoroutinefunction +from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast + +from agentdhal_core import CancellationToken, Component +from pydantic import BaseModel +from typing_extensions import Self + +from ..base import Response +from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent +from ._base_chat_agent import BaseChatAgent + +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + + +# TODO: check if using to_thread fixes this in jupyter +async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + task: asyncio.Task[str] = asyncio.create_task(asyncio.to_thread(input, prompt)) + if cancellation_token is not None: + cancellation_token.link_future(task) + return await task + + +class UserProxyAgentConfig(BaseModel): + """Declarative configuration for the UserProxyAgent.""" + + name: str + description: str = "A human user" + input_func: str | None = None + + +class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]): + """An agent that can represent a human user through an input function. + + This agent can be used to represent a human user in a chat system by providing a custom input function. + + .. note:: + + Using :class:`UserProxyAgent` puts a running team in a temporary blocked + state until the user responds. So it is important to time out the user input + function and cancel using the :class:`~agentdhal_core.CancellationToken` if the user does not respond. + The input function should also handle exceptions and return a default response if needed. + + For typical use cases that involve + slow human responses, it is recommended to use termination conditions + such as :class:`~agentdhal_agentchat.conditions.HandoffTermination` or :class:`~agentdhal_agentchat.conditions.SourceMatchTermination` + to stop the running team and return the control to the application. + You can run the team again with the user input. This way, the state of the team + can be saved and restored when the user responds. + + See `Human-in-the-loop `_ for more information. + + Args: + name (str): The name of the agent. + description (str, optional): A description of the agent. + input_func (Optional[Callable[[str], str]], Callable[[str, Optional[CancellationToken]], Awaitable[str]]): A function that takes a prompt and returns a user input string. + + For examples of integrating with web and UI frameworks, see the following: + + * `FastAPI `_ + * `ChainLit `_ + + Example: + Simple usage case:: + + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_agentchat.agents import UserProxyAgent + from agentdhal_agentchat.messages import TextMessage + + + async def simple_user_agent(): + agent = UserProxyAgent("user_proxy") + response = await asyncio.create_task( + agent.on_messages( + [TextMessage(content="What is your name? ", source="user")], + cancellation_token=CancellationToken(), + ) + ) + assert isinstance(response.chat_message, TextMessage) + print(f"Your name is {response.chat_message.content}") + + Example: + Cancellable usage case:: + + import asyncio + from typing import Any + from agentdhal_core import CancellationToken + from agentdhal_agentchat.agents import UserProxyAgent + from agentdhal_agentchat.messages import TextMessage + + + token = CancellationToken() + agent = UserProxyAgent("user_proxy") + + + async def timeout(delay: float): + await asyncio.sleep(delay) + + + def cancellation_callback(task: asyncio.Task[Any]): + token.cancel() + + + async def cancellable_user_agent(): + try: + timeout_task = asyncio.create_task(timeout(3)) + timeout_task.add_done_callback(cancellation_callback) + agent_task = asyncio.create_task( + agent.on_messages( + [TextMessage(content="What is your name? ", source="user")], + cancellation_token=token, + ) + ) + response = await agent_task + assert isinstance(response.chat_message, TextMessage) + print(f"Your name is {response.chat_message.content}") + except Exception as e: + print(f"Exception: {e}") + except BaseException as e: + print(f"BaseException: {e}") + """ + + component_type = "agent" + component_provider_override = "agentdhal_agentchat.agents.UserProxyAgent" + component_config_schema = UserProxyAgentConfig + + class InputRequestContext: + def __init__(self) -> None: + raise RuntimeError( + "InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests." + ) + + _INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR") + + @classmethod + @contextmanager + def populate_context(cls, ctx: str) -> Generator[None, Any, None]: + """:meta private:""" + token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx) + try: + yield + finally: + UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token) + + @classmethod + def request_id(cls) -> str: + try: + return cls._INPUT_REQUEST_CONTEXT_VAR.get() + except LookupError as e: + raise RuntimeError( + "InputRequestContext.runtime() must be called within the input callback of a UserProxyAgent." + ) from e + + def __init__( + self, + name: str, + *, + description: str = "A human user", + input_func: Optional[InputFuncType] = None, + ) -> None: + """Initialize the UserProxyAgent.""" + super().__init__(name=name, description=description) + self.input_func = input_func or cancellable_input + self._is_async = iscoroutinefunction(self.input_func) + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """Message types this agent can produce.""" + return (TextMessage, HandoffMessage) + + def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]: + """Find the HandoffMessage in the message sequence that addresses this agent.""" + if len(messages) > 0 and isinstance(messages[-1], HandoffMessage): + if messages[-1].target == self.name: + return messages[-1] + else: + raise RuntimeError(f"Handoff message target does not match agent name: {messages[-1].source}") + return None + + async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + """Handle input based on function signature.""" + try: + if self._is_async: + # Cast to AsyncInputFunc for proper typing + async_func = cast(AsyncInputFunc, self.input_func) + return await async_func(prompt, cancellation_token) + else: + # Cast to SyncInputFunc for proper typing + sync_func = cast(SyncInputFunc, self.input_func) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, sync_func, prompt) + + except asyncio.CancelledError: + raise + except Exception as e: + raise RuntimeError(f"Failed to get user input: {str(e)}") from e + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """Handle incoming messages by requesting user input.""" + try: + # Check for handoff first + handoff = self._get_latest_handoff(messages) + prompt = ( + f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: " + ) + + request_id = str(uuid.uuid4()) + + input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name) + yield input_requested_event + with UserProxyAgent.InputRequestContext.populate_context(request_id): + user_input = await self._get_input(prompt, cancellation_token) + + # Return appropriate message type based on handoff presence + if handoff: + yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)) + else: + yield Response(chat_message=TextMessage(content=user_input, source=self.name)) + + except asyncio.CancelledError: + raise + except Exception as e: + raise RuntimeError(f"Failed to get user input: {str(e)}") from e + + async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None: + """Reset agent state.""" + pass + + def _to_config(self) -> UserProxyAgentConfig: + # TODO: Add ability to serialie input_func + return UserProxyAgentConfig(name=self.name, description=self.description, input_func=None) + + @classmethod + def _from_config(cls, config: UserProxyAgentConfig) -> Self: + return cls(name=config.name, description=config.description, input_func=None) diff --git a/agent_dhal/agentdhal_agentchat/base/__init__.py b/agent_dhal/agentdhal_agentchat/base/__init__.py new file mode 100644 index 0000000..dcb0a24 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/__init__.py @@ -0,0 +1,18 @@ +from ._chat_agent import ChatAgent, Response +from ._handoff import Handoff +from ._task import TaskResult, TaskRunner +from ._team import Team +from ._termination import AndTerminationCondition, OrTerminationCondition, TerminatedException, TerminationCondition + +__all__ = [ + "ChatAgent", + "Response", + "Team", + "TerminatedException", + "TerminationCondition", + "AndTerminationCondition", + "OrTerminationCondition", + "TaskResult", + "TaskRunner", + "Handoff", +] diff --git a/agent_dhal/agentdhal_agentchat/base/_chat_agent.py b/agent_dhal/agentdhal_agentchat/base/_chat_agent.py new file mode 100644 index 0000000..5c7138c --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/_chat_agent.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Mapping, Sequence + +from agentdhal_core import CancellationToken, ComponentBase +from pydantic import BaseModel, SerializeAsAny + +from ..messages import BaseAgentEvent, BaseChatMessage +from ._task import TaskRunner + + +@dataclass(kw_only=True) +class Response: + """A response from calling :meth:`ChatAgent.on_messages`.""" + + chat_message: SerializeAsAny[BaseChatMessage] + """A chat message produced by the agent as the response.""" + + inner_messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] | None = None + """Inner messages produced by the agent, they can be :class:`BaseAgentEvent` + or :class:`BaseChatMessage`.""" + + +class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]): + """Protocol for a chat agent.""" + + component_type = "agent" + + @property + @abstractmethod + def name(self) -> str: + """The name of the agent. This is used by team to uniquely identify + the agent. It should be unique within the team.""" + ... + + @property + @abstractmethod + def description(self) -> str: + """The description of the agent. This is used by team to + make decisions about which agents to use. The description should + describe the agent's capabilities and how to interact with it.""" + ... + + @property + @abstractmethod + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """The types of messages that the agent produces in the + :attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types.""" + ... + + @abstractmethod + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + """Handles incoming messages and returns a response.""" + ... + + @abstractmethod + def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """Handles incoming messages and returns a stream of inner messages and + and the final item is the response.""" + ... + + @abstractmethod + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """Resets the agent to its initialization state.""" + ... + + @abstractmethod + async def on_pause(self, cancellation_token: CancellationToken) -> None: + """Called when the agent is paused. The agent may be running in :meth:`on_messages` or + :meth:`on_messages_stream` when this method is called.""" + ... + + @abstractmethod + async def on_resume(self, cancellation_token: CancellationToken) -> None: + """Called when the agent is resumed. The agent may be running in :meth:`on_messages` or + :meth:`on_messages_stream` when this method is called.""" + ... + + @abstractmethod + async def save_state(self) -> Mapping[str, Any]: + """Save agent state for later restoration""" + ... + + @abstractmethod + async def load_state(self, state: Mapping[str, Any]) -> None: + """Restore agent from saved state""" + ... + + @abstractmethod + async def close(self) -> None: + """Release any resources held by the agent.""" + ... diff --git a/agent_dhal/agentdhal_agentchat/base/_handoff.py b/agent_dhal/agentdhal_agentchat/base/_handoff.py new file mode 100644 index 0000000..bf4eeb2 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/_handoff.py @@ -0,0 +1,62 @@ +import logging +from typing import Any, Dict + +from agentdhal_core.tools import BaseTool, FunctionTool +from pydantic import BaseModel, Field, model_validator + +from .. import EVENT_LOGGER_NAME + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class Handoff(BaseModel): + """Handoff configuration.""" + + target: str + """The name of the target agent to handoff to.""" + + description: str = Field(default="") + """The description of the handoff such as the condition under which it should happen and the target agent's ability. + If not provided, it is generated from the target agent's name.""" + + name: str = Field(default="") + """The name of this handoff configuration. If not provided, it is generated from the target agent's name.""" + + message: str = Field(default="") + """The message to the target agent. + By default, it will be the result for the handoff tool. + If not provided, it is generated from the target agent's name.""" + + @model_validator(mode="before") + @classmethod + def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not values.get("description"): + values["description"] = f"Handoff to {values['target']}." + if not values.get("name"): + values["name"] = f"transfer_to_{values['target']}".lower() + else: + name = values["name"] + if not isinstance(name, str): + raise ValueError(f"Handoff name must be a string: {values['name']}") + # Check if name is a valid identifier. + if not name.isidentifier(): + raise ValueError(f"Handoff name must be a valid identifier: {values['name']}") + if not values.get("message"): + values["message"] = ( + f"Transferred to {values['target']}, adopting the role of {values['target']} immediately." + ) + return values + + @property + def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]: + """Create a handoff tool from this handoff configuration.""" + + def _handoff_tool() -> str: + return self.message + + return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True) + + """ + The tool that can be used to handoff to the target agent. + Typically, the results of the tool's execution are provided to the target agent. + """ diff --git a/agent_dhal/agentdhal_agentchat/base/_task.py b/agent_dhal/agentdhal_agentchat/base/_task.py new file mode 100644 index 0000000..3bb0be5 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/_task.py @@ -0,0 +1,65 @@ +from typing import AsyncGenerator, Protocol, Sequence + +from agentdhal_core import CancellationToken +from pydantic import BaseModel, SerializeAsAny + +from ..messages import BaseAgentEvent, BaseChatMessage + + +class TaskResult(BaseModel): + """Result of running a task.""" + + messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] + """Messages produced by the task.""" + + stop_reason: str | None = None + """The reason the task stopped.""" + + +class TaskRunner(Protocol): + """A task runner.""" + + async def run( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> TaskResult: + """Run the task and return the result. + + The task can be a string, a single message, or a sequence of messages. + + The runner is stateful and a subsequent call to this method will continue + from where the previous call left off. If the task is not specified, + the runner will continue with the current task. + + Args: + task: The task to run. Can be a string, a single message, or a sequence of messages. + cancellation_token: The cancellation token to kill the task immediately. + output_task_messages: Whether to include task messages in :attr:`TaskResult.messages`. Defaults to True for backward compatibility. + """ + ... + + def run_stream( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]: + """Run the task and produces a stream of messages and the final result + :class:`TaskResult` as the last item in the stream. + + The task can be a string, a single message, or a sequence of messages. + + The runner is stateful and a subsequent call to this method will continue + from where the previous call left off. If the task is not specified, + the runner will continue with the current task. + + Args: + task: The task to run. Can be a string, a single message, or a sequence of messages. + cancellation_token: The cancellation token to kill the task immediately. + output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility. + """ + ... diff --git a/agent_dhal/agentdhal_agentchat/base/_team.py b/agent_dhal/agentdhal_agentchat/base/_team.py new file mode 100644 index 0000000..67ce416 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/_team.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod +from typing import Any, Mapping + +from agentdhal_core import ComponentBase +from pydantic import BaseModel + +from ._task import TaskRunner + + +class Team(ABC, TaskRunner, ComponentBase[BaseModel]): + component_type = "team" + + @property + @abstractmethod + def name(self) -> str: + """The name of the team. This is used by team to uniquely identify itself + in a larger team of teams.""" + ... + + @property + @abstractmethod + def description(self) -> str: + """A description of the team. This is used to provide context about the + team and its purpose to its parent orchestrator.""" + ... + + @abstractmethod + async def reset(self) -> None: + """Reset the team and all its participants to its initial state.""" + ... + + @abstractmethod + async def pause(self) -> None: + """Pause the team and all its participants. This is useful for + pausing the :meth:`agentdhal_agentchat.base.TaskRunner.run` or + :meth:`agentdhal_agentchat.base.TaskRunner.run_stream` methods from + concurrently, while keeping them alive.""" + ... + + @abstractmethod + async def resume(self) -> None: + """Resume the team and all its participants from a pause after + :meth:`pause` was called.""" + ... + + @abstractmethod + async def save_state(self) -> Mapping[str, Any]: + """Save the current state of the team.""" + ... + + @abstractmethod + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load the state of the team.""" + ... diff --git a/agent_dhal/agentdhal_agentchat/base/_termination.py b/agent_dhal/agentdhal_agentchat/base/_termination.py new file mode 100644 index 0000000..d4471ff --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/base/_termination.py @@ -0,0 +1,179 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import List, Sequence + +from agentdhal_core import Component, ComponentBase, ComponentModel +from pydantic import BaseModel +from typing_extensions import Self + +from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage + + +class TerminatedException(BaseException): ... + + +class TerminationCondition(ABC, ComponentBase[BaseModel]): + """A stateful condition that determines when a conversation should be terminated. + + A termination condition is a callable that takes a sequence of BaseChatMessage objects + since the last time the condition was called, and returns a StopMessage if the + conversation should be terminated, or None otherwise. + Once a termination condition has been reached, it must be reset before it can be used again. + + Termination conditions can be combined using the AND and OR operators. + + Example: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.conditions import MaxMessageTermination, TextMentionTermination + + + async def main() -> None: + # Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned. + cond1 = MaxMessageTermination(10) | TextMentionTermination("TERMINATE") + + # Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned. + cond2 = MaxMessageTermination(10) & TextMentionTermination("TERMINATE") + + # ... + + # Reset the termination condition. + await cond1.reset() + await cond2.reset() + + + asyncio.run(main()) + """ + + component_type = "termination" + + @property + @abstractmethod + def terminated(self) -> bool: + """Check if the termination condition has been reached""" + ... + + @abstractmethod + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + """Check if the conversation should be terminated based on the messages received + since the last time the condition was called. + Return a StopMessage if the conversation should be terminated, or None otherwise. + + Args: + messages: The messages received since the last time the condition was called. + + Returns: + StopMessage | None: A StopMessage if the conversation should be terminated, or None otherwise. + + Raises: + TerminatedException: If the termination condition has already been reached.""" + ... + + @abstractmethod + async def reset(self) -> None: + """Reset the termination condition.""" + ... + + def __and__(self, other: "TerminationCondition") -> "TerminationCondition": + """Combine two termination conditions with an AND operation.""" + return AndTerminationCondition(self, other) + + def __or__(self, other: "TerminationCondition") -> "TerminationCondition": + """Combine two termination conditions with an OR operation.""" + return OrTerminationCondition(self, other) + + +class AndTerminationConditionConfig(BaseModel): + conditions: List[ComponentModel] + + +class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]): + component_config_schema = AndTerminationConditionConfig + component_type = "termination" + component_provider_override = "agentdhal_agentchat.base.AndTerminationCondition" + + def __init__(self, *conditions: TerminationCondition) -> None: + self._conditions = conditions + self._stop_messages: List[StopMessage] = [] + + @property + def terminated(self) -> bool: + return all(condition.terminated for condition in self._conditions) + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self.terminated: + raise TerminatedException("Termination condition has already been reached.") + # Check all remaining conditions. + stop_messages = await asyncio.gather( + *[condition(messages) for condition in self._conditions if not condition.terminated] + ) + # Collect stop messages. + for stop_message in stop_messages: + if stop_message is not None: + self._stop_messages.append(stop_message) + if any(stop_message is None for stop_message in stop_messages): + # If any remaining condition has not reached termination, it is not terminated. + return None + content = ", ".join(stop_message.content for stop_message in self._stop_messages) + source = ", ".join(stop_message.source for stop_message in self._stop_messages) + return StopMessage(content=content, source=source) + + async def reset(self) -> None: + for condition in self._conditions: + await condition.reset() + self._stop_messages.clear() + + def _to_config(self) -> AndTerminationConditionConfig: + """Convert the AND termination condition to a config.""" + return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions]) + + @classmethod + def _from_config(cls, config: AndTerminationConditionConfig) -> Self: + """Create an AND termination condition from a config.""" + conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions] + return cls(*conditions) + + +class OrTerminationConditionConfig(BaseModel): + conditions: List[ComponentModel] + """List of termination conditions where any one being satisfied is sufficient.""" + + +class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]): + component_config_schema = OrTerminationConditionConfig + component_type = "termination" + component_provider_override = "agentdhal_agentchat.base.OrTerminationCondition" + + def __init__(self, *conditions: TerminationCondition) -> None: + self._conditions = conditions + + @property + def terminated(self) -> bool: + return any(condition.terminated for condition in self._conditions) + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self.terminated: + raise RuntimeError("Termination condition has already been reached") + stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions]) + stop_messages_filter = [stop_message for stop_message in stop_messages if stop_message is not None] + if len(stop_messages_filter) > 0: + content = ", ".join(stop_message.content for stop_message in stop_messages_filter) + source = ", ".join(stop_message.source for stop_message in stop_messages_filter) + return StopMessage(content=content, source=source) + return None + + async def reset(self) -> None: + for condition in self._conditions: + await condition.reset() + + def _to_config(self) -> OrTerminationConditionConfig: + """Convert the OR termination condition to a config.""" + return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions]) + + @classmethod + def _from_config(cls, config: OrTerminationConditionConfig) -> Self: + """Create an OR termination condition from a config.""" + conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions] + return cls(*conditions) diff --git a/agent_dhal/agentdhal_agentchat/conditions/__init__.py b/agent_dhal/agentdhal_agentchat/conditions/__init__.py new file mode 100644 index 0000000..72b6174 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/conditions/__init__.py @@ -0,0 +1,32 @@ +""" +This module provides various termination conditions for controlling the behavior of +multi-agent teams. +""" + +from ._terminations import ( + ExternalTermination, + FunctionalTermination, + FunctionCallTermination, + HandoffTermination, + MaxMessageTermination, + SourceMatchTermination, + StopMessageTermination, + TextMentionTermination, + TextMessageTermination, + TimeoutTermination, + TokenUsageTermination, +) + +__all__ = [ + "MaxMessageTermination", + "TextMentionTermination", + "StopMessageTermination", + "TokenUsageTermination", + "HandoffTermination", + "TimeoutTermination", + "ExternalTermination", + "SourceMatchTermination", + "TextMessageTermination", + "FunctionCallTermination", + "FunctionalTermination", +] diff --git a/agent_dhal/agentdhal_agentchat/conditions/_terminations.py b/agent_dhal/agentdhal_agentchat/conditions/_terminations.py new file mode 100644 index 0000000..4cdeb76 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/conditions/_terminations.py @@ -0,0 +1,614 @@ +import asyncio +import time +from typing import Awaitable, Callable, List, Sequence + +from agentdhal_core import Component +from pydantic import BaseModel +from typing_extensions import Self + +from ..base import TerminatedException, TerminationCondition +from ..messages import ( + BaseAgentEvent, + BaseChatMessage, + HandoffMessage, + StopMessage, + TextMessage, + ToolCallExecutionEvent, +) + + +class StopMessageTerminationConfig(BaseModel): + pass + + +class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]): + """Terminate the conversation if a StopMessage is received.""" + + component_config_schema = StopMessageTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.StopMessageTermination" + + def __init__(self) -> None: + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if isinstance(message, StopMessage): + self._terminated = True + return StopMessage(content="Stop message received", source="StopMessageTermination") + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> StopMessageTerminationConfig: + return StopMessageTerminationConfig() + + @classmethod + def _from_config(cls, config: StopMessageTerminationConfig) -> Self: + return cls() + + +class MaxMessageTerminationConfig(BaseModel): + max_messages: int + include_agent_event: bool = False + + +class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]): + """Terminate the conversation after a maximum number of messages have been exchanged. + + Args: + max_messages: The maximum number of messages allowed in the conversation. + include_agent_event: If True, include :class:`~agentdhal_agentchat.messages.BaseAgentEvent` in the message count. + Otherwise, only include :class:`~agentdhal_agentchat.messages.BaseChatMessage`. Defaults to False. + """ + + component_config_schema = MaxMessageTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.MaxMessageTermination" + + def __init__(self, max_messages: int, include_agent_event: bool = False) -> None: + self._max_messages = max_messages + self._message_count = 0 + self._include_agent_event = include_agent_event + + @property + def terminated(self) -> bool: + return self._message_count >= self._max_messages + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self.terminated: + raise TerminatedException("Termination condition has already been reached") + self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)]) + if self._message_count >= self._max_messages: + return StopMessage( + content=f"Maximum number of messages {self._max_messages} reached, current message count: {self._message_count}", + source="MaxMessageTermination", + ) + return None + + async def reset(self) -> None: + self._message_count = 0 + + def _to_config(self) -> MaxMessageTerminationConfig: + return MaxMessageTerminationConfig( + max_messages=self._max_messages, include_agent_event=self._include_agent_event + ) + + @classmethod + def _from_config(cls, config: MaxMessageTerminationConfig) -> Self: + return cls(max_messages=config.max_messages, include_agent_event=config.include_agent_event) + + +class TextMentionTerminationConfig(BaseModel): + text: str + + +class TextMentionTermination(TerminationCondition, Component[TextMentionTerminationConfig]): + """Terminate the conversation if a specific text is mentioned. + + + Args: + text: The text to look for in the messages. + sources: Check only messages of the specified agents for the text to look for. + """ + + component_config_schema = TextMentionTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.TextMentionTermination" + + def __init__(self, text: str, sources: Sequence[str] | None = None) -> None: + self._termination_text = text + self._terminated = False + self._sources = sources + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if self._sources is not None and message.source not in self._sources: + continue + + content = message.to_text() + if self._termination_text in content: + self._terminated = True + return StopMessage( + content=f"Text '{self._termination_text}' mentioned", source="TextMentionTermination" + ) + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> TextMentionTerminationConfig: + return TextMentionTerminationConfig(text=self._termination_text) + + @classmethod + def _from_config(cls, config: TextMentionTerminationConfig) -> Self: + return cls(text=config.text) + + +class FunctionalTermination(TerminationCondition): + """Terminate the conversation if an functional expression is met. + + Args: + func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], bool] | Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[bool]]): A function that takes a sequence of messages + and returns True if the termination condition is met, False otherwise. + The function can be a callable or an async callable. + + Example: + + .. code-block:: python + + import asyncio + from typing import Sequence + + from agentdhal_agentchat.conditions import FunctionalTermination + from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, StopMessage + + + def expression(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> bool: + # Check if the last message is a stop message + return isinstance(messages[-1], StopMessage) + + + termination = FunctionalTermination(expression) + + + async def run() -> None: + messages = [ + StopMessage(source="agent1", content="Stop"), + ] + result = await termination(messages) + print(result) + + + asyncio.run(run()) + + .. code-block:: text + + StopMessage(source="FunctionalTermination", content="Functional termination condition met") + + """ + + def __init__( + self, + func: Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], bool] + | Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[bool]], + ) -> None: + self._func = func + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + if asyncio.iscoroutinefunction(self._func): + result = await self._func(messages) + else: + result = self._func(messages) + if result is True: + self._terminated = True + return StopMessage(content="Functional termination condition met", source="FunctionalTermination") + return None + + async def reset(self) -> None: + self._terminated = False + + +class TokenUsageTerminationConfig(BaseModel): + max_total_token: int | None + max_prompt_token: int | None + max_completion_token: int | None + + +class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminationConfig]): + """Terminate the conversation if a token usage limit is reached. + + Args: + max_total_token: The maximum total number of tokens allowed in the conversation. + max_prompt_token: The maximum number of prompt tokens allowed in the conversation. + max_completion_token: The maximum number of completion tokens allowed in the conversation. + + Raises: + ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided. + """ + + component_config_schema = TokenUsageTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.TokenUsageTermination" + + def __init__( + self, + max_total_token: int | None = None, + max_prompt_token: int | None = None, + max_completion_token: int | None = None, + ) -> None: + if max_total_token is None and max_prompt_token is None and max_completion_token is None: + raise ValueError( + "At least one of max_total_token, max_prompt_token, or max_completion_token must be provided" + ) + self._max_total_token = max_total_token + self._max_prompt_token = max_prompt_token + self._max_completion_token = max_completion_token + self._total_token_count = 0 + self._prompt_token_count = 0 + self._completion_token_count = 0 + + @property + def terminated(self) -> bool: + return ( + (self._max_total_token is not None and self._total_token_count >= self._max_total_token) + or (self._max_prompt_token is not None and self._prompt_token_count >= self._max_prompt_token) + or (self._max_completion_token is not None and self._completion_token_count >= self._max_completion_token) + ) + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self.terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if message.models_usage is not None: + self._prompt_token_count += message.models_usage.prompt_tokens + self._completion_token_count += message.models_usage.completion_tokens + self._total_token_count += message.models_usage.prompt_tokens + message.models_usage.completion_tokens + if self.terminated: + content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}." + return StopMessage(content=content, source="TokenUsageTermination") + return None + + async def reset(self) -> None: + self._total_token_count = 0 + self._prompt_token_count = 0 + self._completion_token_count = 0 + + def _to_config(self) -> TokenUsageTerminationConfig: + return TokenUsageTerminationConfig( + max_total_token=self._max_total_token, + max_prompt_token=self._max_prompt_token, + max_completion_token=self._max_completion_token, + ) + + @classmethod + def _from_config(cls, config: TokenUsageTerminationConfig) -> Self: + return cls( + max_total_token=config.max_total_token, + max_prompt_token=config.max_prompt_token, + max_completion_token=config.max_completion_token, + ) + + +class HandoffTerminationConfig(BaseModel): + target: str + + +class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfig]): + """Terminate the conversation if a :class:`~agentdhal_agentchat.messages.HandoffMessage` + with the given target is received. + + Args: + target (str): The target of the handoff message. + """ + + component_config_schema = HandoffTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.HandoffTermination" + + def __init__(self, target: str) -> None: + self._terminated = False + self._target = target + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if isinstance(message, HandoffMessage) and message.target == self._target: + self._terminated = True + return StopMessage( + content=f"Handoff to {self._target} from {message.source} detected.", source="HandoffTermination" + ) + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> HandoffTerminationConfig: + return HandoffTerminationConfig(target=self._target) + + @classmethod + def _from_config(cls, config: HandoffTerminationConfig) -> Self: + return cls(target=config.target) + + +class TimeoutTerminationConfig(BaseModel): + timeout_seconds: float + + +class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfig]): + """Terminate the conversation after a specified duration has passed. + + Args: + timeout_seconds: The maximum duration in seconds before terminating the conversation. + """ + + component_config_schema = TimeoutTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.TimeoutTermination" + + def __init__(self, timeout_seconds: float) -> None: + self._timeout_seconds = timeout_seconds + self._start_time = time.monotonic() + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + + if (time.monotonic() - self._start_time) >= self._timeout_seconds: + self._terminated = True + return StopMessage( + content=f"Timeout of {self._timeout_seconds} seconds reached", source="TimeoutTermination" + ) + return None + + async def reset(self) -> None: + self._start_time = time.monotonic() + self._terminated = False + + def _to_config(self) -> TimeoutTerminationConfig: + return TimeoutTerminationConfig(timeout_seconds=self._timeout_seconds) + + @classmethod + def _from_config(cls, config: TimeoutTerminationConfig) -> Self: + return cls(timeout_seconds=config.timeout_seconds) + + +class ExternalTerminationConfig(BaseModel): + pass + + +class ExternalTermination(TerminationCondition, Component[ExternalTerminationConfig]): + """A termination condition that is externally controlled + by calling the :meth:`set` method. + + Example: + + .. code-block:: python + + from agentdhal_agentchat.conditions import ExternalTermination + + termination = ExternalTermination() + + # Run the team in an asyncio task. + ... + + # Set the termination condition externally + termination.set() + + """ + + component_config_schema = ExternalTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.ExternalTermination" + + def __init__(self) -> None: + self._terminated = False + self._setted = False + + @property + def terminated(self) -> bool: + return self._terminated + + def set(self) -> None: + """Set the termination condition to terminated.""" + self._setted = True + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + if self._setted: + self._terminated = True + return StopMessage(content="External termination requested", source="ExternalTermination") + return None + + async def reset(self) -> None: + self._terminated = False + self._setted = False + + def _to_config(self) -> ExternalTerminationConfig: + return ExternalTerminationConfig() + + @classmethod + def _from_config(cls, config: ExternalTerminationConfig) -> Self: + return cls() + + +class SourceMatchTerminationConfig(BaseModel): + sources: List[str] + + +class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminationConfig]): + """Terminate the conversation after a specific source responds. + + Args: + sources (List[str]): List of source names to terminate the conversation. + + Raises: + TerminatedException: If the termination condition has already been reached. + """ + + component_config_schema = SourceMatchTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.SourceMatchTermination" + + def __init__(self, sources: List[str]) -> None: + self._sources = sources + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + if not messages: + return None + for message in messages: + if message.source in self._sources: + self._terminated = True + return StopMessage(content=f"'{message.source}' answered", source="SourceMatchTermination") + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> SourceMatchTerminationConfig: + return SourceMatchTerminationConfig(sources=self._sources) + + @classmethod + def _from_config(cls, config: SourceMatchTerminationConfig) -> Self: + return cls(sources=config.sources) + + +class TextMessageTerminationConfig(BaseModel): + """Configuration for the TextMessageTermination termination condition.""" + + source: str | None = None + """The source of the text message to terminate the conversation.""" + + +class TextMessageTermination(TerminationCondition, Component[TextMessageTerminationConfig]): + """Terminate the conversation if a :class:`~agentdhal_agentchat.messages.TextMessage` is received. + + This termination condition checks for TextMessage instances in the message sequence. When a TextMessage is found, + it terminates the conversation if either: + - No source was specified (terminates on any TextMessage) + - The message source matches the specified source + + Args: + source (str | None, optional): The source name to match against incoming messages. If None, matches any source. + Defaults to None. + """ + + component_config_schema = TextMessageTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.TextMessageTermination" + + def __init__(self, source: str | None = None) -> None: + self._terminated = False + self._source = source + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if isinstance(message, TextMessage) and (self._source is None or message.source == self._source): + self._terminated = True + return StopMessage( + content=f"Text message received from '{message.source}'", source="TextMessageTermination" + ) + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> TextMessageTerminationConfig: + return TextMessageTerminationConfig(source=self._source) + + @classmethod + def _from_config(cls, config: TextMessageTerminationConfig) -> Self: + return cls(source=config.source) + + +class FunctionCallTerminationConfig(BaseModel): + """Configuration for the :class:`FunctionCallTermination` termination condition.""" + + function_name: str + + +class FunctionCallTermination(TerminationCondition, Component[FunctionCallTerminationConfig]): + """Terminate the conversation if a :class:`~agentdhal_core.models.FunctionExecutionResult` + with a specific name was received. + + Args: + function_name (str): The name of the function to look for in the messages. + + Raises: + TerminatedException: If the termination condition has already been reached. + """ + + component_config_schema = FunctionCallTerminationConfig + component_provider_override = "agentdhal_agentchat.conditions.FunctionCallTermination" + """The schema for the component configuration.""" + + def __init__(self, function_name: str) -> None: + self._terminated = False + self._function_name = function_name + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + for message in messages: + if isinstance(message, ToolCallExecutionEvent): + for execution in message.content: + if execution.name == self._function_name: + self._terminated = True + return StopMessage( + content=f"Function '{self._function_name}' was executed.", + source="FunctionCallTermination", + ) + return None + + async def reset(self) -> None: + self._terminated = False + + def _to_config(self) -> FunctionCallTerminationConfig: + return FunctionCallTerminationConfig( + function_name=self._function_name, + ) + + @classmethod + def _from_config(cls, config: FunctionCallTerminationConfig) -> Self: + return cls( + function_name=config.function_name, + ) diff --git a/agent_dhal/agentdhal_agentchat/messages.py b/agent_dhal/agentdhal_agentchat/messages.py new file mode 100644 index 0000000..1da4dce --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/messages.py @@ -0,0 +1,693 @@ +""" +This module defines various message types used for agent-to-agent communication. +Each message type inherits either from the BaseChatMessage class or BaseAgentEvent +class and includes specific fields relevant to the type of message being sent. +""" + +import uuid +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar + +from agentdhal_core import Component, ComponentBase, FunctionCall, Image +from agentdhal_core.code_executor import CodeBlock, CodeResult +from agentdhal_core.memory import MemoryContent +from agentdhal_core.models import ( + FunctionExecutionResult, + LLMMessage, + RequestUsage, + UserMessage, +) +from agentdhal_core.utils import schema_to_pydantic_model +from pydantic import BaseModel, Field, computed_field +from typing_extensions import Annotated, Self + + +class BaseMessage(BaseModel, ABC): + """Abstract base class for all message types in AgentChat. + + .. warning:: + + If you want to create a new message type, do not inherit from this class. + Instead, inherit from :class:`BaseChatMessage` or :class:`BaseAgentEvent` + to clarify the purpose of the message type. + + """ + + @abstractmethod + def to_text(self) -> str: + """Convert the message content to a string-only representation + that can be rendered in the console and inspected by the user or conditions. + This is not used for creating text-only content for models. + For :class:`BaseChatMessage` types, use :meth:`to_model_text` instead.""" + ... + + def dump(self) -> Mapping[str, Any]: + """Convert the message to a JSON-serializable dictionary. + + The default implementation uses the Pydantic model's + :meth:`model_dump` method to convert the message to a dictionary. + Datetime objects are automatically converted to ISO format strings + to ensure JSON serialization compatibility. + Override this method if you want to customize the serialization + process or add additional fields to the output. + """ + return self.model_dump(mode="json") + + @classmethod + def load(cls, data: Mapping[str, Any]) -> Self: + """Create a message from a dictionary of JSON-serializable data. + + The default implementation uses the Pydantic model's + :meth:`model_validate` method to create the message from the data. + Override this method if you want to customize the deserialization + process or add additional fields to the input data.""" + return cls.model_validate(data) + + +class BaseChatMessage(BaseMessage, ABC): + """Abstract base class for chat messages. + + .. note:: + + If you want to create a new message type that is used for agent-to-agent + communication, inherit from this class, or simply use + :class:`StructuredMessage` if your content type is a subclass of + Pydantic BaseModel. + + This class is used for messages that are sent between agents in a chat + conversation. Agents are expected to process the content of the + message using models and return a response as another :class:`BaseChatMessage`. + """ + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """Unique identifier for this message.""" + + source: str + """The name of the agent that sent this message.""" + + models_usage: RequestUsage | None = None + """The model client usage incurred when producing this message.""" + + metadata: Dict[str, str] = {} + """Additional metadata about the message.""" + + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """The time when the message was created.""" + + @abstractmethod + def to_model_text(self) -> str: + """Convert the content of the message to text-only representation. + This is used for creating text-only content for models. + + This is not used for rendering the message in console. For that, use + :meth:`~BaseMessage.to_text`. + + The difference between this and :meth:`to_model_message` is that this + is used to construct parts of the a message for the model client, + while :meth:`to_model_message` is used to create a complete message + for the model client. + """ + ... + + @abstractmethod + def to_model_message(self) -> UserMessage: + """Convert the message content to a :class:`~agentdhal_core.models.UserMessage` + for use with model client, e.g., :class:`~agentdhal_core.models.ChatCompletionClient`. + """ + ... + + +class BaseTextChatMessage(BaseChatMessage, ABC): + """Base class for all text-only :class:`BaseChatMessage` types. + It has implementations for :meth:`to_text`, :meth:`to_model_text`, + and :meth:`to_model_message` methods. + + Inherit from this class if your message content type is a string. + """ + + content: str + """The content of the message.""" + + def to_text(self) -> str: + return self.content + + def to_model_text(self) -> str: + return self.content + + def to_model_message(self) -> UserMessage: + return UserMessage(content=self.content, source=self.source) + + +class BaseAgentEvent(BaseMessage, ABC): + """Base class for agent events. + + .. note:: + + If you want to create a new message type for signaling observable events + to user and application, inherit from this class. + + Agent events are used to signal actions and thoughts produced by agents + and teams to user and applications. They are not used for agent-to-agent + communication and are not expected to be processed by other agents. + + You should override the :meth:`to_text` method if you want to provide + a custom rendering of the content. + """ + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """Unique identifier for this event.""" + + source: str + """The name of the agent that sent this message.""" + + models_usage: RequestUsage | None = None + """The model client usage incurred when producing this message.""" + + metadata: Dict[str, str] = {} + """Additional metadata about the message.""" + + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """The time when the message was created.""" + + +StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True) +"""Type variable for structured content types.""" + + +class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]): + """A :class:`BaseChatMessage` type with an unspecified content type. + + To create a new structured message type, specify the content type + as a subclass of `Pydantic BaseModel `_. + + .. code-block:: python + + from pydantic import BaseModel + from agentdhal_agentchat.messages import StructuredMessage + + + class MyMessageContent(BaseModel): + text: str + number: int + + + message = StructuredMessage[MyMessageContent]( + content=MyMessageContent(text="Hello", number=42), + source="agent1", + ) + + print(message.to_text()) # {"text": "Hello", "number": 42} + + .. code-block:: python + + from pydantic import BaseModel + from agentdhal_agentchat.messages import StructuredMessage + + + class MyMessageContent(BaseModel): + text: str + number: int + + + message = StructuredMessage[MyMessageContent]( + content=MyMessageContent(text="Hello", number=42), + source="agent", + format_string="Hello, {text} {number}!", + ) + + print(message.to_text()) # Hello, agent 42! + + """ + + content: StructuredContentType + """The content of the message. Must be a subclass of + `Pydantic BaseModel `_.""" + + format_string: Optional[str] = None + """(Experimental) An optional format string to render the content into a human-readable format. + The format string can use the fields of the content model as placeholders. + For example, if the content model has a field `name`, you can use + `{name}` in the format string to include the value of that field. + The format string is used in the :meth:`to_text` method to create a + human-readable representation of the message. + This setting is experimental and will change in the future. + """ + + @computed_field + def type(self) -> str: + return self.__class__.__name__ + + def to_text(self) -> str: + if self.format_string is not None: + return self.format_string.format(**self.content.model_dump()) + else: + return self.content.model_dump_json() + + def to_model_text(self) -> str: + if self.format_string is not None: + return self.format_string.format(**self.content.model_dump()) + else: + return self.content.model_dump_json() + + def to_model_message(self) -> UserMessage: + return UserMessage( + content=self.content.model_dump_json(), + source=self.source, + ) + + +class StructureMessageConfig(BaseModel): + """The declarative configuration for the structured output.""" + + json_schema: Dict[str, Any] + format_string: Optional[str] = None + content_model_name: str + + +class StructuredMessageFactory(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]): + """:meta private: + + A component that creates structured chat messages from Pydantic models or JSON schemas. + + This component helps you generate strongly-typed chat messages with content defined using a Pydantic model. + It can be used in declarative workflows where message structure must be validated, formatted, and serialized. + + You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration + object (e.g., loaded from disk or a database). + + ### Example 1: Create from a Pydantic Model + + .. code-block:: python + + from pydantic import BaseModel + from agentdhal_agentchat.messages import StructuredMessageFactory + + + class TestContent(BaseModel): + field1: str + field2: int + + + format_string = "This is a string {field1} and this is an int {field2}" + sm_component = StructuredMessageFactory(input_model=TestContent, format_string=format_string) + + message = sm_component.StructuredMessage( + source="test_agent", content=TestContent(field1="Hello", field2=42), format_string=format_string + ) + + print(message.to_model_text()) # Output: This is a string Hello and this is an int 42 + + config = sm_component.dump_component() + + s_m_dyn = StructuredMessageFactory.load_component(config) + message = s_m_dyn.StructuredMessage( + source="test_agent", + content=s_m_dyn.ContentModel(field1="dyn agent", field2=43), + format_string=s_m_dyn.format_string, + ) + print(type(message)) # StructuredMessage[GeneratedModel] + print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43 + + Attributes: + component_config_schema (StructureMessageConfig): Defines the configuration structure for this component. + component_provider_override (str): Path used to reference this component in external tooling. + component_type (str): Identifier used for categorization (e.g., "structured_message"). + + Raises: + ValueError: If neither `json_schema` nor `input_model` is provided. + + Args: + json_schema (Optional[str]): JSON schema to dynamically create a Pydantic model. + input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure. + format_string (Optional[str]): Optional string to render content into a human-readable format. + content_model_name (Optional[str]): Optional name for the generated Pydantic model. + """ + + component_config_schema = StructureMessageConfig + component_provider_override = "agentdhal_agentchat.messages.StructuredMessageFactory" + component_type = "structured_message" + + def __init__( + self, + json_schema: Optional[Dict[str, Any]] = None, + input_model: Optional[Type[BaseModel]] = None, + format_string: Optional[str] = None, + content_model_name: Optional[str] = None, + ) -> None: + self.format_string = format_string + + if json_schema: + self.ContentModel = schema_to_pydantic_model( + json_schema, model_name=content_model_name or "GeneratedContentModel" + ) + elif input_model: + self.ContentModel = input_model + else: + raise ValueError("Either `json_schema` or `input_model` must be provided.") + + self.StructuredMessage = StructuredMessage[self.ContentModel] # type: ignore[name-defined] + + def _to_config(self) -> StructureMessageConfig: + return StructureMessageConfig( + json_schema=self.ContentModel.model_json_schema(), + format_string=self.format_string, + content_model_name=self.ContentModel.__name__, + ) + + @classmethod + def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageFactory": + return cls( + json_schema=config.json_schema, + format_string=config.format_string, + content_model_name=config.content_model_name, + ) + + +class TextMessage(BaseTextChatMessage): + """A text message with string-only content.""" + + type: Literal["TextMessage"] = "TextMessage" + + +class MultiModalMessage(BaseChatMessage): + """A multimodal message.""" + + content: List[str | Image] + """The content of the message.""" + + type: Literal["MultiModalMessage"] = "MultiModalMessage" + + def to_model_text(self, image_placeholder: str | None = "[image]") -> str: + """Convert the content of the message to a string-only representation. + If an image is present, it will be replaced with the image placeholder + by default, otherwise it will be a base64 string when set to None. + """ + text = "" + for c in self.content: + if isinstance(c, str): + text += c + elif isinstance(c, Image): + if image_placeholder is not None: + text += f" {image_placeholder}" + else: + text += f" {c.to_base64()}" + return text + + def to_text(self, iterm: bool = False) -> str: + result: List[str] = [] + for c in self.content: + if isinstance(c, str): + result.append(c) + else: + if iterm: + # iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html + image_data = c.to_base64() + result.append(f"\033]1337;File=inline=1:{image_data}\a\n") + else: + result.append("") + return "\n".join(result) + + def to_model_message(self) -> UserMessage: + return UserMessage(content=self.content, source=self.source) + + +class StopMessage(BaseTextChatMessage): + """A message requesting stop of a conversation.""" + + type: Literal["StopMessage"] = "StopMessage" + + +class HandoffMessage(BaseTextChatMessage): + """A message requesting handoff of a conversation to another agent.""" + + target: str + """The name of the target agent to handoff to.""" + + context: List[LLMMessage] = [] + """The model context to be passed to the target agent.""" + + type: Literal["HandoffMessage"] = "HandoffMessage" + + +class ToolCallSummaryMessage(BaseTextChatMessage): + """A message signaling the summary of tool call results.""" + + type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" + + tool_calls: List[FunctionCall] + """The tool calls that were made.""" + + results: List[FunctionExecutionResult] + """The results of the tool calls.""" + + +class ToolCallRequestEvent(BaseAgentEvent): + """An event signaling a request to use tools.""" + + content: List[FunctionCall] + """The tool calls.""" + + type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent" + + def to_text(self) -> str: + return str(self.content) + + +class CodeGenerationEvent(BaseAgentEvent): + """An event signaling code generation event.""" + + retry_attempt: int + "Retry number, 0 means first generation" + + content: str + "The complete content as string." + + code_blocks: List[CodeBlock] + "List of code blocks present in content" + + type: Literal["CodeGenerationEvent"] = "CodeGenerationEvent" + + def to_text(self) -> str: + return self.content + + +class CodeExecutionEvent(BaseAgentEvent): + """An event signaling code execution event.""" + + retry_attempt: int + "Retry number, 0 means first execution" + + result: CodeResult + "Code Execution Result" + + type: Literal["CodeExecutionEvent"] = "CodeExecutionEvent" + + def to_text(self) -> str: + return self.result.output + + +class ToolCallExecutionEvent(BaseAgentEvent): + """An event signaling the execution of tool calls.""" + + content: List[FunctionExecutionResult] + """The tool call results.""" + + type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent" + + def to_text(self) -> str: + return str(self.content) + + +class UserInputRequestedEvent(BaseAgentEvent): + """An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback.""" + + request_id: str + """Identifier for the user input request.""" + + content: Literal[""] = "" + """Empty content for compat with consumers expecting a content field.""" + + type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent" + + def to_text(self) -> str: + return str(self.content) + + +class MemoryQueryEvent(BaseAgentEvent): + """An event signaling the results of memory queries.""" + + content: List[MemoryContent] + """The memory query results.""" + + type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent" + + def to_text(self) -> str: + return str(self.content) + + +class ModelClientStreamingChunkEvent(BaseAgentEvent): + """An event signaling a text output chunk from a model client in streaming mode.""" + + content: str + """A string chunk from the model client.""" + + full_message_id: str | None = None + """Optional reference to the complete message that may come after the chunks. + This allows consumers of the stream to correlate chunks with the eventual completed message.""" + + type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent" + + def to_text(self) -> str: + return self.content + + +class ThoughtEvent(BaseAgentEvent): + """An event signaling the thought process of a model. + It is used to communicate the reasoning tokens generated by a reasoning model, + or the extra text content generated by a function call.""" + + content: str + """The thought process of the model.""" + + type: Literal["ThoughtEvent"] = "ThoughtEvent" + + def to_text(self) -> str: + return self.content + + +class SelectSpeakerEvent(BaseAgentEvent): + """An event signaling the selection of speakers for a conversation.""" + + content: List[str] + """The names of the selected speakers.""" + + type: Literal["SelectSpeakerEvent"] = "SelectSpeakerEvent" + + def to_text(self) -> str: + return str(self.content) + + +class SelectorEvent(BaseAgentEvent): + """An event emitted from the `SelectorGroupChat`.""" + + content: str + """The content of the event.""" + + type: Literal["SelectorEvent"] = "SelectorEvent" + + def to_text(self) -> str: + return str(self.content) + + +class MessageFactory: + """:meta private: + + A factory for creating messages from JSON-serializable dictionaries. + + This is useful for deserializing messages from JSON data. + """ + + def __init__(self) -> None: + self._message_types: Dict[str, type[BaseAgentEvent | BaseChatMessage]] = {} + # Register all message types. + self._message_types[TextMessage.__name__] = TextMessage + self._message_types[MultiModalMessage.__name__] = MultiModalMessage + self._message_types[StopMessage.__name__] = StopMessage + self._message_types[ToolCallSummaryMessage.__name__] = ToolCallSummaryMessage + self._message_types[HandoffMessage.__name__] = HandoffMessage + self._message_types[ToolCallRequestEvent.__name__] = ToolCallRequestEvent + self._message_types[ToolCallExecutionEvent.__name__] = ToolCallExecutionEvent + self._message_types[MemoryQueryEvent.__name__] = MemoryQueryEvent + self._message_types[UserInputRequestedEvent.__name__] = UserInputRequestedEvent + self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent + self._message_types[ThoughtEvent.__name__] = ThoughtEvent + self._message_types[SelectSpeakerEvent.__name__] = SelectSpeakerEvent + self._message_types[CodeGenerationEvent.__name__] = CodeGenerationEvent + self._message_types[CodeExecutionEvent.__name__] = CodeExecutionEvent + + def is_registered(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> bool: + """Check if a message type is registered with the factory.""" + # Get the class name of the message type. + class_name = message_type.__name__ + # Check if the class name is already registered. + return class_name in self._message_types + + def register(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> None: + """Register a new message type with the factory.""" + if self.is_registered(message_type): + raise ValueError(f"Message type {message_type} is already registered.") + if not issubclass(message_type, BaseChatMessage) and not issubclass(message_type, BaseAgentEvent): + raise ValueError(f"Message type {message_type} must be a subclass of BaseChatMessage or BaseAgentEvent.") + # Get the class name of the + class_name = message_type.__name__ + # Check if the class name is already registered. + # Register the message type. + self._message_types[class_name] = message_type + + def create(self, data: Mapping[str, Any]) -> BaseAgentEvent | BaseChatMessage: + """Create a message from a dictionary of JSON-serializable data.""" + # Get the type of the message from the dictionary. + message_type = data.get("type") + if message_type is None: + raise ValueError("Field 'type' is required in the message data to recover the message type.") + if message_type not in self._message_types: + raise ValueError(f"Unknown message type: {message_type}") + if not isinstance(message_type, str): + raise ValueError(f"Message type must be a string, got {type(message_type)}") + + # Get the class for the message type. + message_class = self._message_types[message_type] + + # Create an instance of the message class. + assert issubclass(message_class, BaseChatMessage) or issubclass(message_class, BaseAgentEvent) + return message_class.load(data) + + +ChatMessage = Annotated[ + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, + Field(discriminator="type"), +] +"""The union type of all built-in concrete subclasses of :class:`BaseChatMessage`. +It does not include :class:`StructuredMessage` types.""" + +AgentEvent = Annotated[ + ToolCallRequestEvent + | ToolCallExecutionEvent + | MemoryQueryEvent + | UserInputRequestedEvent + | ModelClientStreamingChunkEvent + | ThoughtEvent + | SelectSpeakerEvent + | CodeGenerationEvent + | CodeExecutionEvent, + Field(discriminator="type"), +] +"""The union type of all built-in concrete subclasses of :class:`BaseAgentEvent`.""" + +__all__ = [ + "AgentEvent", + "BaseMessage", + "ChatMessage", + "BaseChatMessage", + "BaseAgentEvent", + "BaseTextChatMessage", + "StructuredContentType", + "StructuredMessage", + "StructuredMessageFactory", + "HandoffMessage", + "MultiModalMessage", + "StopMessage", + "TextMessage", + "ToolCallExecutionEvent", + "ToolCallRequestEvent", + "ToolCallSummaryMessage", + "MemoryQueryEvent", + "UserInputRequestedEvent", + "ModelClientStreamingChunkEvent", + "ThoughtEvent", + "SelectSpeakerEvent", + "MessageFactory", + "CodeGenerationEvent", + "CodeExecutionEvent", +] diff --git a/agent_dhal/agentdhal_agentchat/py.typed b/agent_dhal/agentdhal_agentchat/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_agentchat/state/__init__.py b/agent_dhal/agentdhal_agentchat/state/__init__.py new file mode 100644 index 0000000..3cb3efa --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/state/__init__.py @@ -0,0 +1,27 @@ +"""State management for agents, teams and termination conditions.""" + +from ._states import ( + AssistantAgentState, + BaseGroupChatManagerState, + BaseState, + ChatAgentContainerState, + MagenticOneOrchestratorState, + RoundRobinManagerState, + SelectorManagerState, + SocietyOfMindAgentState, + SwarmManagerState, + TeamState, +) + +__all__ = [ + "BaseState", + "AssistantAgentState", + "BaseGroupChatManagerState", + "ChatAgentContainerState", + "RoundRobinManagerState", + "SelectorManagerState", + "SwarmManagerState", + "MagenticOneOrchestratorState", + "TeamState", + "SocietyOfMindAgentState", +] diff --git a/agent_dhal/agentdhal_agentchat/state/_states.py b/agent_dhal/agentdhal_agentchat/state/_states.py new file mode 100644 index 0000000..a27199a --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/state/_states.py @@ -0,0 +1,79 @@ +from typing import Any, List, Mapping, Optional + +from pydantic import BaseModel, Field + + +class BaseState(BaseModel): + """Base class for all saveable state""" + + type: str = Field(default="BaseState") + version: str = Field(default="1.0.0") + + +class HalState(BaseState): + """State for an assistant agent.""" + + llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])])) + type: str = Field(default="AssistantAgentState") + + +class TeamState(BaseState): + """State for a team of agents.""" + + agent_states: Mapping[str, Any] = Field(default_factory=dict) + type: str = Field(default="TeamState") + + +class BaseGroupChatManagerState(BaseState): + """Base state for all group chat managers.""" + + message_thread: List[Mapping[str, Any]] = Field(default_factory=list) + current_turn: int = Field(default=0) + type: str = Field(default="BaseGroupChatManagerState") + + +class ChatAgentContainerState(BaseState): + """State for a container of chat agents.""" + + agent_state: Mapping[str, Any] = Field(default_factory=dict) + message_buffer: List[Mapping[str, Any]] = Field(default_factory=list) + type: str = Field(default="ChatAgentContainerState") + + +class RoundRobinManagerState(BaseGroupChatManagerState): + """State for :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` manager.""" + + next_speaker_index: int = Field(default=0) + type: str = Field(default="RoundRobinManagerState") + + +class SelectorManagerState(BaseGroupChatManagerState): + """State for :class:`~agentdhal_agentchat.teams.SelectorGroupChat` manager.""" + + previous_speaker: Optional[str] = Field(default=None) + type: str = Field(default="SelectorManagerState") + + +class SwarmManagerState(BaseGroupChatManagerState): + """State for :class:`~agentdhal_agentchat.teams.Swarm` manager.""" + + current_speaker: str = Field(default="") + type: str = Field(default="SwarmManagerState") + + +class MagenticOneOrchestratorState(BaseGroupChatManagerState): + """State for :class:`~agentdhal_agentchat.teams.MagneticOneGroupChat` orchestrator.""" + + task: str = Field(default="") + facts: str = Field(default="") + plan: str = Field(default="") + n_rounds: int = Field(default=0) + n_stalls: int = Field(default=0) + type: str = Field(default="MagenticOneOrchestratorState") + + +class SocietyOfMindAgentState(BaseState): + """State for a Society of Mind agent.""" + + inner_team_state: Mapping[str, Any] = Field(default_factory=dict) + type: str = Field(default="SocietyOfMindAgentState") diff --git a/agent_dhal/agentdhal_agentchat/teams/__init__.py b/agent_dhal/agentdhal_agentchat/teams/__init__.py new file mode 100644 index 0000000..7129769 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/__init__.py @@ -0,0 +1,30 @@ +""" +This module provides implementation of various pre-defined multi-agent teams. +Each team inherits from the BaseGroupChat class. +""" + +from ._group_chat._base_group_chat import BaseGroupChat +from ._group_chat._graph import ( + DiGraph, + DiGraphBuilder, + DiGraphEdge, + DiGraphNode, + GraphFlow, +) +from ._group_chat._magentic_one import MagenticOneGroupChat +from ._group_chat._round_robin_group_chat import RoundRobinGroupChat +from ._group_chat._selector_group_chat import SelectorGroupChat +from ._group_chat._swarm_group_chat import Swarm + +__all__ = [ + "BaseGroupChat", + "RoundRobinGroupChat", + "SelectorGroupChat", + "Swarm", + "MagenticOneGroupChat", + "DiGraphBuilder", + "DiGraph", + "DiGraphNode", + "DiGraphEdge", + "GraphFlow", +] diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/__init__.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat.py new file mode 100644 index 0000000..bad66c8 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat.py @@ -0,0 +1,834 @@ +import asyncio +import uuid +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence + +from agentdhal_core import ( + AgentId, + AgentRuntime, + AgentType, + CancellationToken, + ComponentBase, + SingleThreadedAgentRuntime, + TypeSubscription, +) +from pydantic import BaseModel, ValidationError + +from ...base import ChatAgent, TaskResult, Team, TerminationCondition +from ...messages import ( + BaseAgentEvent, + BaseChatMessage, + MessageFactory, + ModelClientStreamingChunkEvent, + StopMessage, + StructuredMessage, + TextMessage, +) +from ...state import TeamState +from ._chat_agent_container import ChatAgentContainer +from ._events import ( + GroupChatPause, + GroupChatReset, + GroupChatResume, + GroupChatStart, + GroupChatTermination, + SerializableException, +) +from ._sequential_routed_agent import SequentialRoutedAgent + + +class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): + """The base class for group chat teams. + + In a group chat team, participants share context by publishing their messages + to all other participants. + + If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's + :attr:`~agentdhal_agentchat.base.Response.chat_message` will be published + to other participants in the group chat. + + If a :class:`~agentdhal_agentchat.base.Team` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` + from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published + to other participants in the group chat. + + To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then + create a subclass of :class:`BaseGroupChat` that uses the group chat manager. + + This base class provides the mapping between the agents of the AgentChat API + and the agent runtime of the Core API, and handles high-level features like + running, pausing, resuming, and resetting the team. + """ + + component_type = "team" + + def __init__( + self, + name: str, + description: str, + participants: List[ChatAgent | Team], + group_chat_manager_name: str, + group_chat_manager_class: type[SequentialRoutedAgent], + termination_condition: TerminationCondition | None = None, + max_turns: int | None = None, + runtime: AgentRuntime | None = None, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + emit_team_events: bool = False, + ): + self._name = name + self._description = description + if len(participants) == 0: + raise ValueError("At least one participant is required.") + if len(participants) != len(set(participant.name for participant in participants)): + raise ValueError("The participant names must be unique.") + self._participants = participants + self._base_group_chat_manager_class = group_chat_manager_class + self._termination_condition = termination_condition + self._max_turns = max_turns + self._message_factory = MessageFactory() + if custom_message_types is not None: + for message_type in custom_message_types: + self._message_factory.register(message_type) + + for agent in participants: + if isinstance(agent, ChatAgent): + for message_type in agent.produced_message_types: + try: + is_registered = self._message_factory.is_registered(message_type) # type: ignore[reportUnknownArgumentType] + if issubclass(message_type, StructuredMessage) and not is_registered: + self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType] + except TypeError: + # Not a class or not a valid subclassable type (skip) + pass + + # The team ID is a UUID that is used to identify the team and its participants + # in the agent runtime. It is used to create unique topic types for each participant. + # Currently, team ID is binded to an object instance of the group chat class. + # So if you create two instances of group chat, there will be two teams with different IDs. + self._team_id = str(uuid.uuid4()) + + # Constants for the group chat team. + # The names are used to identify the agents within the team. + # The names may not be unique across different teams. + self._group_chat_manager_name = group_chat_manager_name + self._participant_names: List[str] = [participant.name for participant in participants] + self._participant_descriptions: List[str] = [participant.description for participant in participants] + # The group chat topic type is used for broadcast communication among all participants and the group chat manager. + self._group_topic_type = f"group_topic_{self._team_id}" + # The group chat manager topic type is used for direct communication with the group chat manager. + self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}" + # The participant topic types are used for direct communication with each participant. + self._participant_topic_types: List[str] = [ + f"{participant.name}_{self._team_id}" for participant in participants + ] + # The output topic type is used for emitting streaming messages from the group chat. + # The group chat manager will relay the messages to the output message queue. + self._output_topic_type = f"output_topic_{self._team_id}" + + # The queue for collecting the output messages. + self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = ( + asyncio.Queue() + ) + + # Create a runtime for the team. + if runtime is not None: + self._runtime = runtime + self._embedded_runtime = False + else: + # Use a embedded single-threaded runtime for the group chat. + # Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination. + self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + self._embedded_runtime = True + + # Flag to track if the group chat has been initialized. + self._initialized = False + + # Flag to track if the group chat is running. + self._is_running = False + + # Flag to track if the team events should be emitted. + self._emit_team_events = emit_team_events + + @property + def name(self) -> str: + """The name of the group chat team.""" + return self._name + + @property + def description(self) -> str: + """A description of the group chat team.""" + return self._description + + @abstractmethod + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], SequentialRoutedAgent]: ... + + def _create_participant_factory( + self, + parent_topic_type: str, + output_topic_type: str, + agent: ChatAgent | Team, + message_factory: MessageFactory, + ) -> Callable[[], ChatAgentContainer]: + def _factory() -> ChatAgentContainer: + container = ChatAgentContainer(parent_topic_type, output_topic_type, agent, message_factory) + return container + + return _factory + + async def _init(self, runtime: AgentRuntime) -> None: + # Constants for the group chat manager. + group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type) + + # Register participants. + # Use the participant topic type as the agent type. + for participant, agent_type in zip(self._participants, self._participant_topic_types, strict=True): + # Register the participant factory. + await ChatAgentContainer.register( + runtime, + type=agent_type, + factory=self._create_participant_factory( + self._group_topic_type, self._output_topic_type, participant, self._message_factory + ), + ) + # Add subscriptions for the participant. + # The participant should be able to receive messages from its own topic. + await runtime.add_subscription(TypeSubscription(topic_type=agent_type, agent_type=agent_type)) + # The participant should be able to receive messages from the group topic. + await runtime.add_subscription(TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type)) + + # Register the group chat manager. + await self._base_group_chat_manager_class.register( + runtime, + type=group_chat_manager_agent_type.type, + factory=self._create_group_chat_manager_factory( + name=self._group_chat_manager_name, + group_topic_type=self._group_topic_type, + output_topic_type=self._output_topic_type, + participant_names=self._participant_names, + participant_topic_types=self._participant_topic_types, + participant_descriptions=self._participant_descriptions, + output_message_queue=self._output_message_queue, + termination_condition=self._termination_condition, + max_turns=self._max_turns, + message_factory=self._message_factory, + ), + ) + # Add subscriptions for the group chat manager. + # The group chat manager should be able to receive messages from the its own topic. + await runtime.add_subscription( + TypeSubscription( + topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type + ) + ) + # The group chat manager should be able to receive messages from the group topic. + await runtime.add_subscription( + TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type) + ) + # The group chat manager will relay the messages from output topic to the output message queue. + await runtime.add_subscription( + TypeSubscription(topic_type=self._output_topic_type, agent_type=group_chat_manager_agent_type.type) + ) + + self._initialized = True + + async def run( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> TaskResult: + """Run the team and return the result. The base implementation uses + :meth:`run_stream` to run the team and then returns the final result. + Once the team is stopped, the termination condition is reset. + + Args: + task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`. + cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. + Setting the cancellation token potentially put the team in an inconsistent state, + and it may not reset the termination condition. + To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead. + + Returns: + result: The result of the task as :class:`~agentdhal_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason. + + Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team: + + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + result = await team.run(task="Count from 1 to 10, respond one at a time.") + print(result) + + # Run the team again without a task to continue the previous task. + result = await team.run() + print(result) + + + asyncio.run(main()) + + + Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_core import CancellationToken + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + cancellation_token = CancellationToken() + + # Create a task to run the team in the background. + run_task = asyncio.create_task( + team.run( + task="Count from 1 to 10, respond one at a time.", + cancellation_token=cancellation_token, + ) + ) + + # Wait for 1 second and then cancel the task. + await asyncio.sleep(1) + cancellation_token.cancel() + + # This will raise a cancellation error. + await run_task + + + asyncio.run(main()) + """ + result: TaskResult | None = None + async for message in self.run_stream( + task=task, + cancellation_token=cancellation_token, + output_task_messages=output_task_messages, + ): + if isinstance(message, TaskResult): + result = message + if result is not None: + return result + raise AssertionError("The stream should have returned the final result.") + + async def run_stream( + self, + *, + task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None, + cancellation_token: CancellationToken | None = None, + output_task_messages: bool = True, + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]: + """Run the team and produces a stream of messages and the final result + of the type :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream. Once the + team is stopped, the termination condition is reset. + + .. note:: + + If an agent produces :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent`, + the message will be yielded in the stream but it will not be included in the + :attr:`~agentdhal_agentchat.base.TaskResult.messages`. + + Args: + task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`. + cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. + Setting the cancellation token potentially put the team in an inconsistent state, + and it may not reset the termination condition. + To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead. + output_task_messages (bool): Whether to include task messages in the output stream. Defaults to True for backward compatibility. + + Returns: + stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~agentdhal_agentchat.messages.BaseAgentEvent`, :class:`~agentdhal_agentchat.messages.BaseChatMessage`, and the final result :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream. + + Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + stream = team.run_stream(task="Count from 1 to 10, respond one at a time.") + async for message in stream: + print(message) + + # Run the team again without a task to continue the previous task. + stream = team.run_stream() + async for message in stream: + print(message) + + + asyncio.run(main()) + + + Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_core import CancellationToken + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + cancellation_token = CancellationToken() + + # Create a task to run the team in the background. + run_task = asyncio.create_task( + Console( + team.run_stream( + task="Count from 1 to 10, respond one at a time.", + cancellation_token=cancellation_token, + ) + ) + ) + + # Wait for 1 second and then cancel the task. + await asyncio.sleep(1) + cancellation_token.cancel() + + # This will raise a cancellation error. + await run_task + + + asyncio.run(main()) + + """ + # Create the messages list if the task is a string or a chat message. + messages: List[BaseChatMessage] | None = None + if task is None: + pass + elif isinstance(task, str): + messages = [TextMessage(content=task, source="user")] + elif isinstance(task, BaseChatMessage): + messages = [task] + elif isinstance(task, list): + if not task: + raise ValueError("Task list cannot be empty.") + messages = [] + for msg in task: + if not isinstance(msg, BaseChatMessage): + raise ValueError("All messages in task list must be valid BaseChatMessage types") + messages.append(msg) + else: + raise ValueError("Task must be a string, a BaseChatMessage, or a list of BaseChatMessage.") + # Check if the messages types are registered with the message factory. + if messages is not None: + for msg in messages: + if not self._message_factory.is_registered(msg.__class__): + raise ValueError( + f"Message type {msg.__class__} is not registered with the message factory. " + "Please register it with the message factory by adding it to the " + "custom_message_types list when creating the team." + ) + + if self._is_running: + raise ValueError("The team is already running, it cannot run again until it is stopped.") + self._is_running = True + + if self._embedded_runtime: + # Start the embedded runtime. + assert isinstance(self._runtime, SingleThreadedAgentRuntime) + self._runtime.start() + + if not self._initialized: + await self._init(self._runtime) + + shutdown_task: asyncio.Task[None] | None = None + if self._embedded_runtime: + + async def stop_runtime() -> None: + assert isinstance(self._runtime, SingleThreadedAgentRuntime) + try: + # This will propagate any exceptions raised. + await self._runtime.stop_when_idle() + # Put a termination message in the queue to indicate that the group chat is stopped for whatever reason + # but not due to an exception. + await self._output_message_queue.put( + GroupChatTermination( + message=StopMessage( + content="The group chat is stopped.", source=self._group_chat_manager_name + ) + ) + ) + except Exception as e: + # Stop the consumption of messages and end the stream. + # NOTE: we also need to put a GroupChatTermination event here because when the runtime + # has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue. + # This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue. + await self._output_message_queue.put( + GroupChatTermination( + message=StopMessage( + content="An exception occurred in the runtime.", source=self._group_chat_manager_name + ), + error=SerializableException.from_exception(e), + ) + ) + + # Create a background task to stop the runtime when the group chat + # is stopped or has an exception. + shutdown_task = asyncio.create_task(stop_runtime()) + + try: + # Run the team by sending the start message to the group chat manager. + # The group chat manager will start the group chat by relaying the message to the participants + # and the group chat manager. + await self._runtime.send_message( + GroupChatStart(messages=messages, output_task_messages=output_task_messages), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + cancellation_token=cancellation_token, + ) + # Collect the output messages in order. + output_messages: List[BaseAgentEvent | BaseChatMessage] = [] + stop_reason: str | None = None + + # Yield the messages until the queue is empty. + while True: + message_future = asyncio.ensure_future(self._output_message_queue.get()) + if cancellation_token is not None: + cancellation_token.link_future(message_future) + # Wait for the next message, this will raise an exception if the task is cancelled. + message = await message_future + if isinstance(message, GroupChatTermination): + # If the message contains an error, we need to raise it here. + # This will stop the team and propagate the error. + if message.error is not None: + raise RuntimeError(str(message.error)) + stop_reason = message.message.content + break + yield message + if isinstance(message, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. + continue + output_messages.append(message) + + # Yield the final result. + yield TaskResult(messages=output_messages, stop_reason=stop_reason) + + finally: + try: + if shutdown_task is not None: + # Wait for the shutdown task to finish. + # This will propagate any exceptions raised. + await shutdown_task + finally: + # Clear the output message queue. + while not self._output_message_queue.empty(): + self._output_message_queue.get_nowait() + + # Indicate that the team is no longer running. + self._is_running = False + + async def reset(self) -> None: + """Reset the team and its participants to their initial state. + + The team must be stopped before it can be reset. + + Raises: + RuntimeError: If the team has not been initialized or is currently running. + + Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + stream = team.run_stream(task="Count from 1 to 10, respond one at a time.") + async for message in stream: + print(message) + + # Reset the team. + await team.reset() + stream = team.run_stream(task="Count from 1 to 10, respond one at a time.") + async for message in stream: + print(message) + + + asyncio.run(main()) + """ + + if not self._initialized: + await self._init(self._runtime) + + if self._is_running: + raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.") + self._is_running = True + + if self._embedded_runtime: + # Start the runtime. + assert isinstance(self._runtime, SingleThreadedAgentRuntime) + self._runtime.start() + + try: + # Send a reset messages to all participants. + for participant_topic_type in self._participant_topic_types: + await self._runtime.send_message( + GroupChatReset(), + recipient=AgentId(type=participant_topic_type, key=self._team_id), + ) + # Send a reset message to the group chat manager. + await self._runtime.send_message( + GroupChatReset(), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + ) + finally: + if self._embedded_runtime: + # Stop the runtime. + assert isinstance(self._runtime, SingleThreadedAgentRuntime) + await self._runtime.stop_when_idle() + + # Reset the output message queue. + while not self._output_message_queue.empty(): + self._output_message_queue.get_nowait() + + # Indicate that the team is no longer running. + self._is_running = False + + async def pause(self) -> None: + """Pause its participants when the team is running by calling their + :meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method via direct RPC calls. + + .. attention:: + + This is an experimental feature introduced in v0.4.9 and may subject + to change or removal in the future. + + The team must be initialized before it can be paused. + + Different from termination, pausing the team does not cause the + :meth:`run` or :meth:`run_stream` method to return. It calls the + :meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method on each + participant, and if the participant does not implement the method, it + will be a no-op. + + .. note:: + + It is the responsibility of the agent class to handle the pause + and ensure that the agent can be resumed later. + Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_pause` + method in your agent class for custom pause behavior. + By default, the agent will not do anything when called. + + Raises: + RuntimeError: If the team has not been initialized. Exceptions from + the participants when calling their implementations of + :class:`~agentdhal_agentchat.base.ChatAgent.on_pause` are + propagated to this method and raised. + """ + if not self._initialized: + raise RuntimeError("The group chat has not been initialized. It must be run before it can be paused.") + + # Send a pause message to all participants. + for participant_topic_type in self._participant_topic_types: + await self._runtime.send_message( + GroupChatPause(), + recipient=AgentId(type=participant_topic_type, key=self._team_id), + ) + # Send a pause message to the group chat manager. + await self._runtime.send_message( + GroupChatPause(), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + ) + + async def resume(self) -> None: + """Resume its participants when the team is running and paused by calling their + :meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method via direct RPC calls. + + .. attention:: + + This is an experimental feature introduced in v0.4.9 and may subject + to change or removal in the future. + + The team must be initialized before it can be resumed. + + Different from termination and restart with a new task, resuming the team + does not cause the :meth:`run` or :meth:`run_stream` method to return. + It calls the :meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method on each + participant, and if the participant does not implement the method, it + will be a no-op. + + .. note:: + + It is the responsibility of the agent class to handle the resume + and ensure that the agent continues from where it was paused. + Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_resume` + method in your agent class for custom resume behavior. + + Raises: + RuntimeError: If the team has not been initialized. Exceptions from + the participants when calling their implementations of :class:`~agentdhal_agentchat.base.ChatAgent.on_resume` + method are propagated to this method and raised. + + """ + if not self._initialized: + raise RuntimeError("The group chat has not been initialized. It must be run before it can be resumed.") + + # Send a resume message to all participants. + for participant_topic_type in self._participant_topic_types: + await self._runtime.send_message( + GroupChatResume(), + recipient=AgentId(type=participant_topic_type, key=self._team_id), + ) + # Send a resume message to the group chat manager. + await self._runtime.send_message( + GroupChatResume(), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + ) + + async def save_state(self) -> Mapping[str, Any]: + """Save the state of the group chat team. + + The state is saved by calling the :meth:`~agentdhal_core.AgentRuntime.agent_save_state` method + on each participant and the group chat manager with their internal agent ID. + The state is returned as a nested dictionary: a dictionary with key `agent_states`, + which is a dictionary the agent names as keys and the state as values. + + .. code-block:: text + + { + "agent_states": { + "agent1": ..., + "agent2": ..., + "RoundRobinGroupChatManager": ... + } + } + + .. note:: + + Starting v0.4.9, the state is using the agent name as the key instead of the agent ID, + and the `team_id` field is removed from the state. This is to allow the state to be + portable across different teams and runtimes. States saved with the old format + may not be compatible with the new format in the future. + + .. caution:: + + When calling :func:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` on a team + while it is running, the state may not be consistent and may result in an unexpected state. + It is recommended to call this method when the team is not running or after it is stopped. + + """ + if not self._initialized: + await self._init(self._runtime) + + # Store state of each agent by their name. + # NOTE: we don't use the agent ID as the key here because we need to be able to decouple + # the state of the agents from their identities in the agent runtime. + agent_states: Dict[str, Mapping[str, Any]] = {} + # Save the state of all participants. + for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True): + agent_id = AgentId(type=agent_type, key=self._team_id) + # NOTE: We are using the runtime's save state method rather than the agent instance's + # save_state method because we want to support saving state of remote agents. + agent_states[name] = await self._runtime.agent_save_state(agent_id) + # Save the state of the group chat manager. + agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id) + agent_states[self._group_chat_manager_name] = await self._runtime.agent_save_state(agent_id) + return TeamState(agent_states=agent_states).model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load an external state and overwrite the current state of the group chat team. + + The state is loaded by calling the :meth:`~agentdhal_core.AgentRuntime.agent_load_state` method + on each participant and the group chat manager with their internal agent ID. + See :meth:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` for the expected format of the state. + """ + if not self._initialized: + await self._init(self._runtime) + + if self._is_running: + raise RuntimeError("The team cannot be loaded while it is running.") + self._is_running = True + + try: + team_state = TeamState.model_validate(state) + # Load the state of all participants. + for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True): + agent_id = AgentId(type=agent_type, key=self._team_id) + if name not in team_state.agent_states: + raise ValueError(f"Agent state for {name} not found in the saved state.") + await self._runtime.agent_load_state(agent_id, team_state.agent_states[name]) + # Load the state of the group chat manager. + agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id) + if self._group_chat_manager_name not in team_state.agent_states: + raise ValueError(f"Agent state for {self._group_chat_manager_name} not found in the saved state.") + await self._runtime.agent_load_state(agent_id, team_state.agent_states[self._group_chat_manager_name]) + + except ValidationError as e: + raise ValueError( + "Invalid state format. The expected state format has changed since v0.4.9. " + "Please read the release note on GitHub." + ) from e + + finally: + # Indicate that the team is no longer running. + self._is_running = False diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat_manager.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat_manager.py new file mode 100644 index 0000000..eaf3e73 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -0,0 +1,326 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import Any, List, Sequence + +from agentdhal_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc + +from ...base import TerminationCondition +from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage +from ._events import ( + GroupChatAgentResponse, + GroupChatError, + GroupChatMessage, + GroupChatPause, + GroupChatRequestPublish, + GroupChatReset, + GroupChatResume, + GroupChatStart, + GroupChatTeamResponse, + GroupChatTermination, + SerializableException, +) +from ._sequential_routed_agent import SequentialRoutedAgent + + +class BaseGroupChatManager(SequentialRoutedAgent, ABC): + """Base class for a group chat manager that manages a group chat with multiple participants. + + It is the responsibility of the caller to ensure: + - All participants must subscribe to the group chat topic and each of their own topics. + - The group chat manager must subscribe to the group chat topic. + - The agent types of the participants must be unique. + - For each participant, the agent type must be the same as the topic type. + + Without the above conditions, the group chat will not function correctly. + """ + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + emit_team_events: bool = False, + ): + super().__init__( + description="Group chat manager", + sequential_message_types=[ + GroupChatStart, + GroupChatAgentResponse, + GroupChatTeamResponse, + GroupChatMessage, + GroupChatReset, + ], + ) + if max_turns is not None and max_turns <= 0: + raise ValueError("The maximum number of turns must be greater than 0.") + if len(participant_topic_types) != len(participant_descriptions): + raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.") + if len(set(participant_topic_types)) != len(participant_topic_types): + raise ValueError("The participant topic types must be unique.") + if group_topic_type in participant_topic_types: + raise ValueError("The group topic type must not be in the participant topic types.") + self._name = name + self._group_topic_type = group_topic_type + self._output_topic_type = output_topic_type + self._participant_names = participant_names + self._participant_name_to_topic_type = { + name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True) + } + self._participant_descriptions = participant_descriptions + self._message_thread: List[BaseAgentEvent | BaseChatMessage] = [] + self._output_message_queue = output_message_queue + self._termination_condition = termination_condition + self._max_turns = max_turns + self._current_turn = 0 + self._message_factory = message_factory + self._emit_team_events = emit_team_events + self._active_speakers: List[str] = [] + + @rpc + async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: + """Handle the start of a group chat by selecting a speaker to start the conversation.""" + + # Check if the conversation has already terminated. + if self._termination_condition is not None and self._termination_condition.terminated: + early_stop_message = StopMessage( + content="The group chat has already terminated.", + source=self._name, + ) + # Signal termination to the caller of the team. + await self._signal_termination(early_stop_message) + # Stop the group chat. + return + + # Validate the group state given the start messages + await self.validate_group_state(message.messages) + + if message.messages is not None: + # Log all messages at once + await self.publish_message( + GroupChatStart(messages=message.messages), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + + # Only put messages in output queue if output_task_messages is True + if message.output_task_messages: + for msg in message.messages: + await self._output_message_queue.put(msg) + + # Relay all messages at once to participants + await self.publish_message( + GroupChatStart(messages=message.messages), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=ctx.cancellation_token, + ) + + # Append all messages to thread + await self.update_message_thread(message.messages) + + # Check termination condition after processing all messages + if await self._apply_termination_condition(message.messages): + # Stop the group chat. + return + + # Select speakers to start/continue the conversation + await self._transition_to_next_speakers(ctx.cancellation_token) + + @event + async def handle_agent_response( + self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext + ) -> None: + try: + # Construct the detla from the agent response. + delta: List[BaseAgentEvent | BaseChatMessage] = [] + if isinstance(message, GroupChatAgentResponse): + if message.response.inner_messages is not None: + for inner_message in message.response.inner_messages: + delta.append(inner_message) + delta.append(message.response.chat_message) + else: + delta.extend(message.result.messages) + + # Append the messages to the message thread. + await self.update_message_thread(delta) + + # Remove the agent from the active speakers list. + self._active_speakers.remove(message.name) + if len(self._active_speakers) > 0: + # If there are still active speakers, return without doing anything. + return + + # Check if the conversation should be terminated. + if await self._apply_termination_condition(delta, increment_turn_count=True): + # Stop the group chat. + return + + # Select speakers to continue the conversation. + await self._transition_to_next_speakers(ctx.cancellation_token) + except Exception as e: + # Handle the exception and signal termination with an error. + error = SerializableException.from_exception(e) + await self._signal_termination_with_error(error) + # Raise the exception to the runtime. + raise + + async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None: + speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + cancellation_token.link_future(speaker_names_future) + speaker_names = await speaker_names_future + if isinstance(speaker_names, str): + # If only one speaker is selected, convert it to a list. + speaker_names = [speaker_names] + for speaker_name in speaker_names: + if speaker_name not in self._participant_name_to_topic_type: + raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") + await self._log_speaker_selection(speaker_names) + + # Send request to publish message to the next speakers + for speaker_name in speaker_names: + speaker_topic_type = self._participant_name_to_topic_type[speaker_name] + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=cancellation_token, + ) + self._active_speakers.append(speaker_name) + + async def _apply_termination_condition( + self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False + ) -> bool: + """Apply the termination condition to the delta and return True if the conversation should be terminated. + It also resets the termination condition and turn count, and signals termination to the caller of the team.""" + if self._termination_condition is not None: + stop_message = await self._termination_condition(delta) + if stop_message is not None: + # Reset the termination conditions and turn count. + await self._termination_condition.reset() + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + # Stop the group chat. + return True + if increment_turn_count: + # Increment the turn count. + self._current_turn += 1 + # Check if the maximum number of turns has been reached. + if self._max_turns is not None: + if self._current_turn >= self._max_turns: + stop_message = StopMessage( + content=f"Maximum number of turns {self._max_turns} reached.", + source=self._name, + ) + # Reset the termination conditions and turn count. + if self._termination_condition is not None: + await self._termination_condition.reset() + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + # Stop the group chat. + return True + return False + + async def _log_speaker_selection(self, speaker_names: List[str]) -> None: + """Log the selected speaker to the output message queue.""" + select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name) + if self._emit_team_events: + await self.publish_message( + GroupChatMessage(message=select_msg), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + await self._output_message_queue.put(select_msg) + + async def _signal_termination(self, message: StopMessage) -> None: + termination_event = GroupChatTermination(message=message) + # Log the early stop message. + await self.publish_message( + termination_event, + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Put the termination event in the output message queue. + await self._output_message_queue.put(termination_event) + + async def _signal_termination_with_error(self, error: SerializableException) -> None: + termination_event = GroupChatTermination( + message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error + ) + # Log the termination event. + await self.publish_message( + termination_event, + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Put the termination event in the output message queue. + await self._output_message_queue.put(termination_event) + + @event + async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None: + """Handle a group chat message by appending the content to its output message queue.""" + await self._output_message_queue.put(message.message) + + @event + async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None: + """Handle a group chat error by logging the error and signaling termination.""" + await self._signal_termination_with_error(message.error) + + @rpc + async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: + """Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager + and clear the message thread.""" + await self.reset() + + @rpc + async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None: + """Pause the group chat manager. This is a no-op in the base class.""" + pass + + @rpc + async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None: + """Resume the group chat manager. This is a no-op in the base class.""" + pass + + @abstractmethod + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + """Validate the state of the group chat given the start messages. + This is executed when the group chat manager receives a GroupChatStart event. + + Args: + messages: A list of chat messages to validate, or None if no messages are provided. + """ + ... + + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + """Update the message thread with the new messages. + This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event, + before calling the select_speakers method. + """ + self._message_thread.extend(messages) + + @abstractmethod + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Select speakers from the participants and return the topic types of the selected speaker. + This is called when the group chat manager have received all responses from the participants + for a turn and is ready to select the next speakers for the next turn. + + Args: + thread: The message thread of the group chat. + + Returns: + A list of topic types of the selected speakers. + If only one speaker is selected, a single string is returned instead of a list. + """ + ... + + @abstractmethod + async def reset(self) -> None: + """Reset the group chat manager.""" + ... + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in group chat manager: {type(message)}") diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_chat_agent_container.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_chat_agent_container.py new file mode 100644 index 0000000..bb37df0 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_chat_agent_container.py @@ -0,0 +1,213 @@ +from typing import Any, List, Mapping + +from agentdhal_core import DefaultTopicId, MessageContext, event, rpc, trace_invoke_agent_span + +from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory + +from ...base import ChatAgent, Response, TaskResult, Team +from ...state import ChatAgentContainerState +from ._events import ( + GroupChatAgentResponse, + GroupChatError, + GroupChatMessage, + GroupChatPause, + GroupChatRequestPublish, + GroupChatReset, + GroupChatResume, + GroupChatStart, + GroupChatTeamResponse, + SerializableException, +) +from ._sequential_routed_agent import SequentialRoutedAgent + + +class ChatAgentContainer(SequentialRoutedAgent): + """A core agent class that delegates message handling to an + :class:`agentdhal_agentchat.base.ChatAgent` or :class:`agentdhal_agentchat.base.Team` + so that it can be used in a group chat team. + + Args: + parent_topic_type (str): The topic type of the parent orchestrator. + output_topic_type (str): The topic type for the output. + agent (ChatAgent | Team): The agent or team to delegate message handling to. + message_factory (MessageFactory): The message factory to use for + creating messages from JSON data. + """ + + def __init__( + self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent | Team, message_factory: MessageFactory + ) -> None: + super().__init__( + description=agent.description, + sequential_message_types=[ + GroupChatStart, + GroupChatRequestPublish, + GroupChatReset, + GroupChatAgentResponse, + GroupChatTeamResponse, + ], + ) + self._parent_topic_type = parent_topic_type + self._output_topic_type = output_topic_type + self._agent = agent + self._message_buffer: List[BaseChatMessage] = [] + self._message_factory = message_factory + + @event + async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: + """Handle a start event by appending the content to the buffer.""" + if message.messages is not None: + for msg in message.messages: + self._buffer_message(msg) + + @event + async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: + """Handle an agent response event by appending the content to the buffer.""" + self._buffer_message(message.response.chat_message) + + @event + async def handle_team_response(self, message: GroupChatTeamResponse, ctx: MessageContext) -> None: + """Handle a team response event by appending the content to the buffer.""" + for msg in message.result.messages: + if isinstance(msg, BaseChatMessage): + self._buffer_message(msg) + + @rpc + async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: + """Handle a reset event by resetting the agent.""" + self._message_buffer.clear() + if isinstance(self._agent, Team): + # If the agent is a team, reset the team. + await self._agent.reset() + else: + await self._agent.on_reset(ctx.cancellation_token) + + @event + async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None: + """Handle a content request event by passing the messages in the buffer + to the delegate agent and publish the response.""" + if isinstance(self._agent, Team): + try: + stream = self._agent.run_stream( + task=self._message_buffer, + cancellation_token=ctx.cancellation_token, + output_task_messages=False, + ) + result: TaskResult | None = None + async for team_event in stream: + if isinstance(team_event, TaskResult): + result = team_event + else: + await self._log_message(team_event) + if result is None: + raise RuntimeError( + "The team did not produce a final TaskResult. Check the team's run_stream method." + ) + self._message_buffer.clear() + # Publish the team response to the group chat. + await self.publish_message( + GroupChatTeamResponse(result=result, name=self._agent.name), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + except Exception as e: + # Publish the error to the group chat. + error_message = SerializableException.from_exception(e) + await self.publish_message( + GroupChatError(error=error_message), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + # Raise the error to the runtime. + raise + else: + # If the agent is not a team, handle it as a single agent. + with trace_invoke_agent_span( + agent_name=self._agent.name, + agent_description=self._agent.description, + agent_id=str(self.id), + ): + try: + # Pass the messages in the buffer to the delegate agent. + response: Response | None = None + async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): + if isinstance(msg, Response): + await self._log_message(msg.chat_message) + response = msg + else: + await self._log_message(msg) + if response is None: + raise RuntimeError( + "The agent did not produce a final response. Check the agent's on_messages_stream method." + ) + # Publish the response to the group chat. + self._message_buffer.clear() + await self.publish_message( + GroupChatAgentResponse(response=response, name=self._agent.name), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + except Exception as e: + # Publish the error to the group chat. + error_message = SerializableException.from_exception(e) + await self.publish_message( + GroupChatError(error=error_message), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + # Raise the error to the runtime. + raise + + def _buffer_message(self, message: BaseChatMessage) -> None: + if not self._message_factory.is_registered(message.__class__): + raise ValueError(f"Message type {message.__class__} is not registered.") + # Buffer the message. + self._message_buffer.append(message) + + async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None: + if not self._message_factory.is_registered(message.__class__): + raise ValueError(f"Message type {message.__class__} is not registered.") + # Log the message. + await self.publish_message( + GroupChatMessage(message=message), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + + @rpc + async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None: + """Handle a pause event by pausing the agent.""" + if isinstance(self._agent, Team): + # If the agent is a team, pause the team. + await self._agent.pause() + else: + await self._agent.on_pause(ctx.cancellation_token) + + @rpc + async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None: + """Handle a resume event by resuming the agent.""" + if isinstance(self._agent, Team): + # If the agent is a team, resume the team. + await self._agent.resume() + else: + await self._agent.on_resume(ctx.cancellation_token) + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in agent container: {type(message)}") + + async def save_state(self) -> Mapping[str, Any]: + agent_state = await self._agent.save_state() + state = ChatAgentContainerState( + agent_state=agent_state, message_buffer=[message.dump() for message in self._message_buffer] + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + container_state = ChatAgentContainerState.model_validate(state) + self._message_buffer = [] + for message_data in container_state.message_buffer: + message = self._message_factory.create(message_data) + if isinstance(message, BaseChatMessage): + self._message_buffer.append(message) + else: + raise ValueError(f"Invalid message type in message buffer: {type(message)}") + await self._agent.load_state(container_state.agent_state) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py new file mode 100644 index 0000000..a149e58 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py @@ -0,0 +1,113 @@ +import traceback +from typing import List + +from pydantic import BaseModel, SerializeAsAny + +from ...base import Response, TaskResult +from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage + + +class SerializableException(BaseModel): + """A serializable exception.""" + + error_type: str + """The type of error that occurred.""" + + error_message: str + """The error message that describes the error.""" + + traceback: str | None = None + """The traceback of the error, if available.""" + + @classmethod + def from_exception(cls, exc: Exception) -> "SerializableException": + """Create a GroupChatError from an exception.""" + return cls( + error_type=type(exc).__name__, + error_message=str(exc), + traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), + ) + + def __str__(self) -> str: + """Return a string representation of the error, including the traceback if available.""" + if self.traceback: + return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}" + return f"{self.error_type}: {self.error_message}" + + +class GroupChatStart(BaseModel): + """A request to start a group chat.""" + + messages: List[SerializeAsAny[BaseChatMessage]] | None = None + """An optional list of messages to start the group chat.""" + + output_task_messages: bool = True + """Whether to include task messages in the output. Defaults to True for backward compatibility.""" + + +class GroupChatAgentResponse(BaseModel): + """A response published to a group chat.""" + + response: SerializeAsAny[Response] + """The response from an agent.""" + + name: str + """The name of the agent that produced the response.""" + + +class GroupChatTeamResponse(BaseModel): + """A response published to a group chat from a team.""" + + result: SerializeAsAny[TaskResult] + """The result from a team.""" + + name: str + """The name of the team that produced the response.""" + + +class GroupChatRequestPublish(BaseModel): + """A request to publish a message to a group chat.""" + + ... + + +class GroupChatMessage(BaseModel): + """A message from a group chat.""" + + message: SerializeAsAny[BaseAgentEvent | BaseChatMessage] + """The message that was published.""" + + +class GroupChatTermination(BaseModel): + """A message indicating that a group chat has terminated.""" + + message: StopMessage + """The stop message that indicates the reason of termination.""" + + error: SerializableException | None = None + """The error that occurred, if any.""" + + +class GroupChatReset(BaseModel): + """A request to reset the agents in the group chat.""" + + ... + + +class GroupChatPause(BaseModel): + """A request to pause the group chat.""" + + ... + + +class GroupChatResume(BaseModel): + """A request to resume the group chat.""" + + ... + + +class GroupChatError(BaseModel): + """A message indicating that an error occurred in the group chat.""" + + error: SerializableException + """The error that occurred.""" diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/__init__.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/__init__.py new file mode 100644 index 0000000..f38d6d6 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/__init__.py @@ -0,0 +1,17 @@ +from ._digraph_group_chat import ( + DiGraph, + DiGraphEdge, + DiGraphNode, + GraphFlow, + GraphFlowManager, +) +from ._graph_builder import DiGraphBuilder + +__all__ = [ + "GraphFlow", + "DiGraph", + "GraphFlowManager", + "DiGraphNode", + "DiGraphEdge", + "DiGraphBuilder", +] diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py new file mode 100644 index 0000000..517336f --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -0,0 +1,877 @@ +import asyncio +from collections import Counter, deque +from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union + +from agentdhal_core import AgentRuntime, Component, ComponentModel +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self + +from agentdhal_agentchat.base import ChatAgent, TerminationCondition +from agentdhal_agentchat.messages import ( + BaseAgentEvent, + BaseChatMessage, + MessageFactory, + StopMessage, +) +from agentdhal_agentchat.state import BaseGroupChatManagerState +from agentdhal_agentchat.teams import BaseGroupChat + +from ..._group_chat._base_group_chat_manager import BaseGroupChatManager +from ..._group_chat._events import GroupChatTermination + +_DIGRAPH_STOP_MESSAGE = "Digraph execution is complete" + + +class DiGraphEdge(BaseModel): + """Represents a directed edge in a :class:`DiGraph`, with an optional execution condition. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + .. warning:: + + If the condition is a callable, it will not be serialized in the model. + + """ + + target: str # Target node name + condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None) + """(Experimental) Condition to execute this edge. + If None, the edge is unconditional. + If a string, the edge is conditional on the presence of that string in the last agent chat message. + If a callable, the edge is conditional on the callable returning True when given the last message. + """ + + # Using Field to exclude the condition in serialization if it's a callable + condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True) + activation_group: str = Field(default="") + """Group identifier for forward dependencies. + + When multiple edges point to the same target node, they are grouped by this field. + This allows distinguishing between different cycles or dependency patterns. + + Example: In a graph containing a cycle like A->B->C->B, the two edges pointing to B (A->B and C->B) + can be in different activation groups to control how B is activated. + Defaults to the target node name if not specified. + """ + activation_condition: Literal["all", "any"] = "all" + """Determines how forward dependencies within the same activation_group are evaluated. + + - "all": All edges in this activation group must be satisfied before the target node can execute + - "any": Any single edge in this activation group being satisfied allows the target node to execute + + This is used to handle complex dependency patterns in cyclic graphs where multiple + paths can lead to the same target node. + """ + + @model_validator(mode="after") + def _validate_condition(self) -> "DiGraphEdge": + # Store callable in a separate field and set condition to None for serialization + if callable(self.condition): + self.condition_function = self.condition + # For serialization purposes, we'll set the condition to None + # when storing as a pydantic model/dict + object.__setattr__(self, "condition", None) + + # Set activation_group to target if not already set + if not self.activation_group: + self.activation_group = self.target + + return self + + def check_condition(self, message: BaseChatMessage) -> bool: + """Check if the edge condition is satisfied for the given message. + + Args: + message: The message to check the condition against. + + Returns: + True if condition is satisfied (None condition always returns True), + False otherwise. + """ + if self.condition_function is not None: + return self.condition_function(message) + elif isinstance(self.condition, str): + # If it's a string, check if the string is in the message content + return self.condition in message.to_model_text() + return True # None condition is always satisfied + + +class DiGraphNode(BaseModel): + """Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + """ + + name: str # Agent's name + edges: List[DiGraphEdge] = [] # Outgoing edges + activation: Literal["all", "any"] = "all" + + +class DiGraph(BaseModel): + """Defines a directed graph structure with nodes and edges. + :class:`GraphFlow` uses this to determine execution order and conditions. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + """ + + nodes: Dict[str, DiGraphNode] # Node name → DiGraphNode mapping + default_start_node: str | None = None # Default start node name + _has_cycles: bool | None = None # Cyclic graph flag + + def get_parents(self) -> Dict[str, List[str]]: + """Compute a mapping of each node to its parent nodes.""" + parents: Dict[str, List[str]] = {node: [] for node in self.nodes} + for node in self.nodes.values(): + for edge in node.edges: + parents[edge.target].append(node.name) + return parents + + def get_start_nodes(self) -> Set[str]: + """Return the nodes that have no incoming edges (entry points).""" + if self.default_start_node: + return {self.default_start_node} + + parents = self.get_parents() + return set([node_name for node_name, parent_list in parents.items() if not parent_list]) + + def get_leaf_nodes(self) -> Set[str]: + """Return nodes that have no outgoing edges (final output nodes).""" + return set([name for name, node in self.nodes.items() if not node.edges]) + + def has_cycles_with_exit(self) -> bool: + """ + Check if the graph has any cycles and validate that each cycle has at least one conditional edge. + + Returns: + bool: True if there is at least one cycle and all cycles have an exit condition. + False if there are no cycles. + + Raises: + ValueError: If there is a cycle without any conditional edge. + """ + visited: Set[str] = set() + rec_stack: Set[str] = set() + path: List[str] = [] + + def dfs(node_name: str) -> bool: + visited.add(node_name) + rec_stack.add(node_name) + path.append(node_name) + + for edge in self.nodes[node_name].edges: + target = edge.target + if target not in visited: + if dfs(target): + return True + elif target in rec_stack: + # Found a cycle → extract the cycle + cycle_start_index = path.index(target) + cycle_nodes = path[cycle_start_index:] + cycle_edges: List[DiGraphEdge] = [] + for n in cycle_nodes: + cycle_edges.extend(self.nodes[n].edges) + if all(edge.condition is None and edge.condition_function is None for edge in cycle_edges): + raise ValueError( + f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}" + ) + return True # Found cycle, but it has an exit condition + + rec_stack.remove(node_name) + path.pop() + return False + + has_cycle = False + for node in self.nodes: + if node not in visited: + if dfs(node): + has_cycle = True + + return has_cycle + + def get_has_cycles(self) -> bool: + """Indicates if the graph has at least one cycle (with valid exit conditions).""" + if self._has_cycles is None: + self._has_cycles = self.has_cycles_with_exit() + + return self._has_cycles + + def graph_validate(self) -> None: + """Validate graph structure and execution rules.""" + if not self.nodes: + raise ValueError("Graph has no nodes.") + + if not self.get_start_nodes(): + raise ValueError("Graph must have at least one start node") + + if not self.get_leaf_nodes(): + raise ValueError("Graph must have at least one leaf node") + + # Outgoing edge condition validation (per node) + for node in self.nodes.values(): + # Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional + has_condition = any( + edge.condition is not None or edge.condition_function is not None for edge in node.edges + ) + has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges) + if has_condition and has_unconditioned: + raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.") + + # Validate activation conditions across all edges in the graph + self._validate_activation_conditions() + + self._has_cycles = self.has_cycles_with_exit() + + def _validate_activation_conditions(self) -> None: + """Validate that all edges pointing to the same target node have consistent activation_condition values. + + Raises: + ValueError: If edges pointing to the same target have different activation_condition values + """ + target_activation_conditions: Dict[str, Dict[str, str]] = {} # target_node -> {activation_group -> condition} + + for node in self.nodes.values(): + for edge in node.edges: + target = edge.target # The target node this edge points to + activation_group = edge.activation_group + + if target not in target_activation_conditions: + target_activation_conditions[target] = {} + + if activation_group in target_activation_conditions[target]: + if target_activation_conditions[target][activation_group] != edge.activation_condition: + # Find the source node that has the conflicting condition + conflicting_source = self._find_edge_source_by_target_and_group( + target, activation_group, target_activation_conditions[target][activation_group] + ) + raise ValueError( + f"Conflicting activation conditions for target '{target}' group '{activation_group}': " + f"'{target_activation_conditions[target][activation_group]}' (from node '{conflicting_source}') " + f"and '{edge.activation_condition}' (from node '{node.name}')" + ) + else: + target_activation_conditions[target][activation_group] = edge.activation_condition + + def _find_edge_source_by_target_and_group( + self, target: str, activation_group: str, activation_condition: str + ) -> str: + """Find the source node that has an edge pointing to the given target with the given activation_group and activation_condition.""" + for node_name, node in self.nodes.items(): + for edge in node.edges: + if ( + edge.target == target + and edge.activation_group == activation_group + and edge.activation_condition == activation_condition + ): + return node_name + return "unknown" + + def get_remaining_map(self) -> Dict[str, Dict[str, int]]: + """Get the remaining map that tracks how many edges point to each target node with each activation group. + + Returns: + Dictionary mapping target nodes to their activation groups and remaining counts + """ + + remaining_map: Dict[str, Dict[str, int]] = {} + + for node in self.nodes.values(): + for edge in node.edges: + target = edge.target + activation_group = edge.activation_group + + if target not in remaining_map: + remaining_map[target] = {} + + if activation_group not in remaining_map[target]: + remaining_map[target][activation_group] = 0 + + remaining_map[target][activation_group] += 1 + + return remaining_map + + +class GraphFlowManagerState(BaseGroupChatManagerState): + """Tracks active execution state for DAG-based execution.""" + + active_nodes: List[str] = [] # Currently executing nodes + type: str = "GraphManagerState" + + +class GraphFlowManager(BaseGroupChatManager): + """Manages execution of agents using a Directed Graph execution model.""" + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + graph: DiGraph, + ) -> None: + """Initialize the graph-based execution manager.""" + super().__init__( + name=name, + group_topic_type=group_topic_type, + output_topic_type=output_topic_type, + participant_topic_types=participant_topic_types, + participant_names=participant_names, + participant_descriptions=participant_descriptions, + output_message_queue=output_message_queue, + termination_condition=termination_condition, + max_turns=max_turns, + message_factory=message_factory, + ) + graph.graph_validate() + if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None: + raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.") + self._graph = graph + # Lookup table for incoming edges for each node. + self._parents = graph.get_parents() + # Lookup table for outgoing edges for each node. + self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()} + + # Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node + self._build_lookup_tables(graph) + + # Track which activation groups were triggered for each node + self._triggered_activation_groups: Dict[str, Set[str]] = {} + # === Mutable states for the graph execution === + # Count the number of remaining parents to activate each node. + self._remaining: Dict[str, Counter[str]] = { + target: Counter(groups) for target, groups in graph.get_remaining_map().items() + } + # cache for remaining + self._origin_remaining: Dict[str, Dict[str, int]] = { + target: Counter(groups) for target, groups in self._remaining.items() + } + + # Ready queue for nodes that are ready to execute, starting with the start nodes. + self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()]) + + def _build_lookup_tables(self, graph: DiGraph) -> None: + """Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node. + + Args: + graph: The directed graph + """ + self._activation: Dict[str, Dict[str, Literal["any", "all"]]] = {} + self._enqueued_any: Dict[str, Dict[str, bool]] = {} + + for node in graph.nodes.values(): + for edge in node.edges: + target = edge.target + activation_group = edge.activation_group + + # Build activation lookup + if target not in self._activation: + self._activation[target] = {} + if activation_group not in self._activation[target]: + self._activation[target][activation_group] = edge.activation_condition + + # Build enqueued_any lookup + if target not in self._enqueued_any: + self._enqueued_any[target] = {} + if activation_group not in self._enqueued_any[target]: + self._enqueued_any[target][activation_group] = False + + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + await super().update_message_thread(messages) + + # Find the node that ran in the current turn. + message = messages[-1] + if message.source not in self._graph.nodes: + # Ignore messages from sources outside of the graph. + return + assert isinstance(message, BaseChatMessage) + source = message.source + + # Propagate the update to the children of the node. + for edge in self._edges[source]: + # Use the new check_condition method that handles both string and callable conditions + if not edge.check_condition(message): + continue + + target = edge.target + activation_group = edge.activation_group + + if self._activation[target][activation_group] == "all": + self._remaining[target][activation_group] -= 1 + if self._remaining[target][activation_group] == 0: + # If all parents are done, add to the ready queue. + self._ready.append(target) + # Track which activation group was triggered + self._save_triggered_activation_group(target, activation_group) + else: + # If activation is any, add to the ready queue if not already enqueued. + if not self._enqueued_any[target][activation_group]: + self._ready.append(target) + self._enqueued_any[target][activation_group] = True + # Track which activation group was triggered + self._save_triggered_activation_group(target, activation_group) + + def _save_triggered_activation_group(self, target: str, activation_group: str) -> None: + """Save which activation group was triggered for a target node. + + Args: + target: The target node that was triggered + activation_group: The activation group that caused the trigger + """ + if target not in self._triggered_activation_groups: + self._triggered_activation_groups[target] = set() + self._triggered_activation_groups[target].add(activation_group) + + def _reset_triggered_activation_groups(self, speaker: str) -> None: + """Reset the bookkeeping for the specific activation groups that were triggered for a speaker. + + Args: + speaker: The speaker node to reset activation groups for + """ + if speaker not in self._triggered_activation_groups: + return + + for activation_group in self._triggered_activation_groups[speaker]: + if self._activation[speaker][activation_group] == "any": + self._enqueued_any[speaker][activation_group] = False + else: + # Reset the remaining count for this activation group using the graph's original count + if speaker in self._remaining and activation_group in self._remaining[speaker]: + self._remaining[speaker][activation_group] = self._origin_remaining[speaker][activation_group] + + # Clear the triggered activation groups for this speaker + self._triggered_activation_groups[speaker].clear() + + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]: + # Drain the ready queue for the next set of speakers. + speakers: List[str] = [] + while self._ready: + speaker = self._ready.popleft() + speakers.append(speaker) + + # Reset the bookkeeping for the specific activation groups that were triggered + self._reset_triggered_activation_groups(speaker) + + return speakers + + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + pass + + async def _apply_termination_condition( + self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False + ) -> bool: + """Apply termination condition including graph-specific completion logic. + + First checks if graph execution is complete, then checks standard termination conditions. + + Args: + delta: The message delta to check termination conditions against + increment_turn_count: Whether to increment the turn count + + Returns: + True if the conversation should be terminated, False otherwise + """ + # Check if the graph execution is complete (no ready speakers) - prioritize this check + if not self._ready: + stop_message = StopMessage( + content=_DIGRAPH_STOP_MESSAGE, + source=self._name, + ) + # Reset the execution state when the graph has naturally completed + self._reset_execution_state() + # Reset the termination conditions and turn count. + if self._termination_condition is not None: + await self._termination_condition.reset() + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + return True + + # Apply the standard termination conditions from the base class + return await super()._apply_termination_condition(delta, increment_turn_count) + + def _reset_execution_state(self) -> None: + """Reset the graph execution state to the initial state.""" + self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()} + self._enqueued_any = {n: {g: False for g in self._enqueued_any[n]} for n in self._enqueued_any} + self._ready = deque([n for n in self._graph.get_start_nodes()]) + + async def save_state(self) -> Mapping[str, Any]: + """Save the execution state.""" + state = { + "message_thread": [message.dump() for message in self._message_thread], + "current_turn": self._current_turn, + "remaining": {target: dict(counter) for target, counter in self._remaining.items()}, + "enqueued_any": dict(self._enqueued_any), + "ready": list(self._ready), + } + return state + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Restore execution state from saved data.""" + self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]] + self._current_turn = state["current_turn"] + self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()} + self._enqueued_any = state["enqueued_any"] + self._ready = deque(state["ready"]) + + async def reset(self) -> None: + """Reset execution state to the start of the graph.""" + self._current_turn = 0 + self._message_thread.clear() + if self._termination_condition: + await self._termination_condition.reset() + self._reset_execution_state() + + +class GraphFlowConfig(BaseModel): + """The declarative configuration for GraphFlow.""" + + name: str | None = None + description: str | None = None + participants: List[ComponentModel] + termination_condition: ComponentModel | None = None + max_turns: int | None = None + graph: DiGraph # The execution graph for agents + + +class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): + """A team that runs a group chat following a Directed Graph execution pattern. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + This group chat executes agents based on a directed graph (:class:`DiGraph`) structure, + allowing complex workflows such as sequential execution, parallel fan-out, + conditional branching, join patterns, and loops with explicit exit conditions. + + The execution order is determined by the edges defined in the `DiGraph`. Each node + in the graph corresponds to an agent, and edges define the flow of messages between agents. + Nodes can be configured to activate when: + + - **All** parent nodes have completed (activation="all") → default + - **Any** parent node completes (activation="any") + + Conditional branching is supported using edge conditions, where the next agent(s) are selected + based on content in the chat history. Loops are permitted as long as there is a condition + that eventually exits the loop. + + .. note:: + + Use the :class:`DiGraphBuilder` class to create a :class:`DiGraph` easily. It provides a fluent API + for adding nodes and edges, setting entry points, and validating the graph structure. + See the :class:`DiGraphBuilder` documentation for more details. + The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows. + + .. warning:: + + When using callable conditions in edges, they will not be serialized + when calling :meth:`dump_component`. This will be addressed in future releases. + + + Args: + participants (List[ChatAgent]): The participants in the group chat. + termination_condition (TerminationCondition, optional): Termination condition for the chat. + max_turns (int, optional): Maximum number of turns before forcing termination. + graph (DiGraph): Directed execution graph defining node flow and conditions. + + Raises: + ValueError: If participant names are not unique, or if graph validation fails (e.g., cycles without exit). + + Examples: + + **Sequential Flow: A → B → C** + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main(): + # Initialize agents with OpenAI model clients. + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.") + agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.") + agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to English.") + + # Create a directed graph with sequential flow A -> B -> C. + builder = DiGraphBuilder() + builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c) + graph = builder.build() + + # Create a GraphFlow team with the directed graph. + team = GraphFlow( + participants=[agent_a, agent_b, agent_c], + graph=graph, + termination_condition=MaxMessageTermination(5), + ) + + # Run the team and print the events. + async for event in team.run_stream(task="Write a short story about a cat."): + print(event) + + + asyncio.run(main()) + + **Parallel Fan-out: A → (B, C)** + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main(): + # Initialize agents with OpenAI model clients. + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.") + agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.") + agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.") + + # Create a directed graph with fan-out flow A -> (B, C). + builder = DiGraphBuilder() + builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c) + graph = builder.build() + + # Create a GraphFlow team with the directed graph. + team = GraphFlow( + participants=[agent_a, agent_b, agent_c], + graph=graph, + termination_condition=MaxMessageTermination(5), + ) + + # Run the team and print the events. + async for event in team.run_stream(task="Write a short story about a cat."): + print(event) + + + asyncio.run(main()) + + **Conditional Branching: A → B (if 'yes') or C (otherwise)** + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main(): + # Initialize agents with OpenAI model clients. + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + agent_a = AssistantAgent( + "A", + model_client=model_client, + system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.", + ) + agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.") + agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.") + + # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise). + builder = DiGraphBuilder() + builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + # Create conditions as callables that check the message content. + builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text()) + builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text()) + graph = builder.build() + + # Create a GraphFlow team with the directed graph. + team = GraphFlow( + participants=[agent_a, agent_b, agent_c], + graph=graph, + termination_condition=MaxMessageTermination(5), + ) + + # Run the team and print the events. + async for event in team.run_stream(task="AutoGen is a framework for building AI agents."): + print(event) + + + asyncio.run(main()) + + **Loop with exit condition: A → B → C (if 'APPROVE') or A (otherwise)** + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import MaxMessageTermination + from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main(): + # Initialize agents with OpenAI model clients. + model_client = OpenAIChatCompletionClient(model="gpt-4.1") + agent_a = AssistantAgent( + "A", + model_client=model_client, + system_message="You are a helpful assistant.", + ) + agent_b = AssistantAgent( + "B", + model_client=model_client, + system_message="Provide feedback on the input, if your feedback has been addressed, " + "say 'APPROVE', otherwise provide a reason for rejection.", + ) + agent_c = AssistantAgent( + "C", model_client=model_client, system_message="Translate the final product to Korean." + ) + + # Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise). + builder = DiGraphBuilder() + builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + builder.add_edge(agent_a, agent_b) + + # Create conditional edges using strings + builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text()) + builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text()) + + builder.set_entry_point(agent_a) + graph = builder.build() + + # Create a GraphFlow team with the directed graph. + team = GraphFlow( + participants=[agent_a, agent_b, agent_c], + graph=graph, + termination_condition=MaxMessageTermination(20), # Max 20 messages to avoid infinite loop. + ) + + # Run the team and print the events. + async for event in team.run_stream(task="Write a short poem about AI Agents."): + print(event) + + + asyncio.run(main()) + """ + + component_config_schema = GraphFlowConfig + component_provider_override = "agentdhal_agentchat.teams.GraphFlow" + + DEFAULT_NAME = "GraphFlow" + DEFAULT_DESCRIPTION = "A team of agents" + + def __init__( + self, + participants: List[ChatAgent], + graph: DiGraph, + *, + name: str | None = None, + description: str | None = None, + termination_condition: TerminationCondition | None = None, + max_turns: int | None = None, + runtime: AgentRuntime | None = None, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + ) -> None: + self._input_participants = participants + self._input_termination_condition = termination_condition + + for participant in participants: + if not isinstance(participant, ChatAgent): + raise TypeError(f"Participant {participant} must be a ChatAgent.") + + # No longer add _StopAgent or StopMessageTermination + # Termination is now handled directly in GraphFlowManager._apply_termination_condition + super().__init__( + name=name or self.DEFAULT_NAME, + description=description or self.DEFAULT_DESCRIPTION, + participants=list(participants), + group_chat_manager_name="GraphManager", + group_chat_manager_class=GraphFlowManager, + termination_condition=termination_condition, + max_turns=max_turns, + runtime=runtime, + custom_message_types=custom_message_types, + ) + self._graph = graph + + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], GraphFlowManager]: + """Creates the factory method for initializing the DiGraph-based chat manager.""" + + def _factory() -> GraphFlowManager: + return GraphFlowManager( + name=name, + group_topic_type=group_topic_type, + output_topic_type=output_topic_type, + participant_topic_types=participant_topic_types, + participant_names=participant_names, + participant_descriptions=participant_descriptions, + output_message_queue=output_message_queue, + termination_condition=termination_condition, + max_turns=max_turns, + message_factory=message_factory, + graph=self._graph, + ) + + return _factory + + def _to_config(self) -> GraphFlowConfig: + """Converts the instance into a configuration object.""" + participants = [participant.dump_component() for participant in self._input_participants] + termination_condition = ( + self._input_termination_condition.dump_component() if self._input_termination_condition else None + ) + return GraphFlowConfig( + name=self._name, + description=self._description, + participants=participants, + termination_condition=termination_condition, + max_turns=self._max_turns, + graph=self._graph, + ) + + @classmethod + def _from_config(cls, config: GraphFlowConfig) -> Self: + """Reconstructs an instance from a configuration object.""" + participants = [ChatAgent.load_component(participant) for participant in config.participants] + termination_condition = ( + TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None + ) + return cls( + name=config.name, + description=config.description, + participants=participants, + graph=config.graph, + termination_condition=termination_condition, + max_turns=config.max_turns, + ) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_graph_builder.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_graph_builder.py new file mode 100644 index 0000000..a6aa4fa --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_graph/_graph_builder.py @@ -0,0 +1,209 @@ +import warnings +from typing import Callable, Dict, Literal, Optional, Union + +from agentdhal_agentchat.base import ChatAgent +from agentdhal_agentchat.messages import BaseChatMessage + +from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode + + +class DiGraphBuilder: + """ + A fluent builder for constructing :class:`DiGraph` execution graphs used in :class:`GraphFlow`. + + .. warning:: + + This is an experimental feature, and the API will change in the future releases. + + This utility provides a convenient way to programmatically build a graph of agent interactions, + including complex execution flows such as: + + - Sequential chains + - Parallel fan-outs + - Conditional branching + - Cyclic loops with safe exits + + Each node in the graph represents an agent. Edges define execution paths between agents, + and can optionally be conditioned on message content using callable functions. + + The builder is compatible with the `Graph` runner and supports both standard and filtered agents. + + Methods: + - add_node(agent, activation): Add an agent node to the graph. + - add_edge(source, target, condition): Connect two nodes optionally with a condition. + - add_conditional_edges(source, condition_to_target): Add multiple conditional edges from a source. + - set_entry_point(agent): Define the default start node (optional). + - build(): Generate a validated `DiGraph`. + - get_participants(): Return the list of added agents. + + Example — Sequential Flow A → B → C: + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c) + >>> team = Graph( + ... participants=builder.get_participants(), + ... graph=builder.build(), + ... termination_condition=MaxMessageTermination(5), + ... ) + + Example — Parallel Fan-out A → (B, C): + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c) + + Example — Conditional Branching A → B or A → C: + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> # Add conditional edges using keyword check + >>> builder.add_edge(agent_a, agent_b, condition="keyword1") + >>> builder.add_edge(agent_a, agent_c, condition="keyword2") + + + Example — Using Custom String Conditions: + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> # Add condition strings to check in messages + >>> builder.add_edge(agent_a, agent_b, condition="big") + >>> builder.add_edge(agent_a, agent_c, condition="small") + + Example — Loop: A → B → A or B → C: + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> builder.add_edge(agent_a, agent_b) + >> # Add a loop back to agent A + >>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text()) + >>> # Add exit condition to break the loop + >>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text()) + + Example — Loop with multiple paths to the same node: A → B → C → B: + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> builder.add_edge(agent_a, agent_b) + >>> builder.add_edge(agent_b, agent_c) + >>> builder.add_edge(agent_c, agent_b, activation_group="loop_back") + + Example — Loop with multiple paths to the same node with any activation condition: A → B → (C1, C2) → B → E(exit): + >>> builder = GraphBuilder() + >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c1).add_node(agent_c2).add_node(agent_e) + >>> builder.add_edge(agent_a, agent_b) + >>> builder.add_edge(agent_b, agent_c1) + >>> builder.add_edge(agent_b, agent_c2) + >>> builder.add_edge(agent_b, agent_e, condition="exit") + >>> builder.add_edge(agent_c1, agent_b, activation_group="loop_back_group", activation_condition="any") + >>> builder.add_edge(agent_c2, agent_b, activation_group="loop_back_group", activation_condition="any") + """ + + def __init__(self) -> None: + self.nodes: Dict[str, DiGraphNode] = {} + self.agents: Dict[str, ChatAgent] = {} + self._default_start_node: Optional[str] = None + + def _get_name(self, obj: Union[str, ChatAgent]) -> str: + return obj if isinstance(obj, str) else obj.name + + def add_node(self, agent: ChatAgent, activation: Literal["all", "any"] = "all") -> "DiGraphBuilder": + """Add a node to the graph and register its agent.""" + name = agent.name + if name not in self.nodes: + self.nodes[name] = DiGraphNode(name=name, edges=[], activation=activation) + self.agents[name] = agent + return self + + def add_edge( + self, + source: Union[str, ChatAgent], + target: Union[str, ChatAgent], + condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None, + activation_group: Optional[str] = None, + activation_condition: Optional[Literal["all", "any"]] = None, + ) -> "DiGraphBuilder": + """Add a directed edge from source to target, optionally with a condition. + + Args: + source: Source node (agent name or agent object) + target: Target node (agent name or agent object) + condition: Optional condition for edge activation. + If string, activates when substring is found in message. + If callable, activates when function returns True for the message. + + Returns: + Self for method chaining + + Raises: + ValueError: If source or target node doesn't exist in the builder + """ + source_name = self._get_name(source) + target_name = self._get_name(target) + + if source_name not in self.nodes: + raise ValueError(f"Source node '{source_name}' must be added before adding an edge.") + if target_name not in self.nodes: + raise ValueError(f"Target node '{target_name}' must be added before adding an edge.") + if activation_group is None: + activation_group = target_name + if activation_condition is None: + activation_condition = "all" + self.nodes[source_name].edges.append( + DiGraphEdge( + target=target_name, + condition=condition, + activation_group=activation_group, + activation_condition=activation_condition, + ) + ) + return self + + def add_conditional_edges( + self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]] + ) -> "DiGraphBuilder": + """Add multiple conditional edges from a source node based on keyword checks. + + .. warning:: + + This method interface will be changed in the future to support callable conditions. + Please use `add_edge` if you need to specify custom conditions. + + Args: + source: Source node (agent name or agent object) + condition_to_target: Mapping from condition strings to target nodes + Each key is a keyword that will be checked in the message content + Each value is the target node to activate when condition is met + + For each key (keyword), a lambda will be created that checks + if the keyword is in the message text. + + Returns: + Self for method chaining + """ + + warnings.warn( + "add_conditional_edges will be changed in the future to support callable conditions. " + "For now, please use add_edge if you need to specify custom conditions.", + DeprecationWarning, + stacklevel=2, + ) + + for condition_keyword, target in condition_to_target.items(): + self.add_edge(source, target, condition=condition_keyword) + return self + + def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder": + """Set the default start node of the graph.""" + node_name = self._get_name(name) + if node_name not in self.nodes: + raise ValueError(f"Start node '{node_name}' must be added before setting as entry point.") + self._default_start_node = node_name + return self + + def build(self) -> DiGraph: + """Build and validate the DiGraph.""" + graph = DiGraph( + nodes=self.nodes, + default_start_node=self._default_start_node, + ) + graph.graph_validate() + return graph + + def get_participants(self) -> list[ChatAgent]: + """Return the list of agents in the builder, in insertion order.""" + return list(self.agents.values()) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/__init__.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/__init__.py new file mode 100644 index 0000000..8ad3b38 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/__init__.py @@ -0,0 +1,5 @@ +from ._magentic_one_group_chat import MagenticOneGroupChat + +__all__ = [ + "MagenticOneGroupChat", +] diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py new file mode 100644 index 0000000..80a848c --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py @@ -0,0 +1,209 @@ +import asyncio +import logging +from typing import Callable, List + +from agentdhal_core import AgentRuntime, Component, ComponentModel +from agentdhal_core.models import ChatCompletionClient +from pydantic import BaseModel +from typing_extensions import Self + +from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME +from ....base import ChatAgent, TerminationCondition +from ....messages import BaseAgentEvent, BaseChatMessage, MessageFactory +from .._base_group_chat import BaseGroupChat +from .._events import GroupChatTermination +from ._magentic_one_orchestrator import MagenticOneOrchestrator +from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT + +trace_logger = logging.getLogger(TRACE_LOGGER_NAME) +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class MagenticOneGroupChatConfig(BaseModel): + """The declarative configuration for a MagenticOneGroupChat.""" + + name: str | None = None + description: str | None = None + participants: List[ComponentModel] + model_client: ComponentModel + termination_condition: ComponentModel | None = None + max_turns: int | None = None + max_stalls: int + final_answer_prompt: str + emit_team_events: bool = False + + +class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]): + """A team that runs a group chat with participants managed by the MagenticOneOrchestrator. + + The orchestrator handles the conversation flow, ensuring that the task is completed + efficiently by managing the participants' interactions. + + The orchestrator is based on the Magentic-One architecture, which is a generalist multi-agent system for solving complex tasks (see references below). + + Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and :class:`~agentdhal_agentchat.teams.SelectorGroupChat`, + the MagenticOneGroupChat does not support using team as participant. + + Args: + participants (List[ChatAgent]): The participants in the group chat. + model_client (ChatCompletionClient): The model client used for generating responses. + termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None. + Without a termination condition, the group chat will run based on the orchestrator logic or until the maximum number of turns is reached. + max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20. + max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3. + final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided. + custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat. + If you are using custom message types or your agents produces custom message types, you need to specify them here. + Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`. + emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + + Raises: + ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid. + + Examples: + + MagenticOneGroupChat with one assistant agent: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import MagenticOneGroupChat + from agentdhal_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + assistant = AssistantAgent( + "Assistant", + model_client=model_client, + ) + team = MagenticOneGroupChat([assistant], model_client=model_client) + await Console(team.run_stream(task="Provide a different proof to Fermat last theorem")) + + + asyncio.run(main()) + + References: + + If you use the MagenticOneGroupChat in your work, please cite the following paper: + + .. code-block:: bibtex + + @article{fourney2024magentic, + title={Magentic-one: A generalist multi-agent system for solving complex tasks}, + author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others}, + journal={arXiv preprint arXiv:2411.04468}, + year={2024} + } + """ + + component_config_schema = MagenticOneGroupChatConfig + component_provider_override = "agentdhal_agentchat.teams.MagenticOneGroupChat" + + DEFAULT_NAME = "MagenticOneGroupChat" + DEFAULT_DESCRIPTION = "A team of agents." + + def __init__( + self, + participants: List[ChatAgent], + model_client: ChatCompletionClient, + *, + name: str | None = None, + description: str | None = None, + termination_condition: TerminationCondition | None = None, + max_turns: int | None = 20, + runtime: AgentRuntime | None = None, + max_stalls: int = 3, + final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + emit_team_events: bool = False, + ): + for participant in participants: + if not isinstance(participant, ChatAgent): + raise TypeError(f"Participant {participant} must be a ChatAgent.") + super().__init__( + name=name or self.DEFAULT_NAME, + description=description or self.DEFAULT_DESCRIPTION, + participants=list(participants), + group_chat_manager_name="MagenticOneOrchestrator", + group_chat_manager_class=MagenticOneOrchestrator, + termination_condition=termination_condition, + max_turns=max_turns, + runtime=runtime, + custom_message_types=custom_message_types, + emit_team_events=emit_team_events, + ) + + # Validate the participants. + if len(participants) == 0: + raise ValueError("At least one participant is required for MagenticOneGroupChat.") + self._model_client = model_client + self._max_stalls = max_stalls + self._final_answer_prompt = final_answer_prompt + + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], MagenticOneOrchestrator]: + return lambda: MagenticOneOrchestrator( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + max_turns, + message_factory, + self._model_client, + self._max_stalls, + self._final_answer_prompt, + output_message_queue, + termination_condition, + self._emit_team_events, + ) + + def _to_config(self) -> MagenticOneGroupChatConfig: + participants = [participant.dump_component() for participant in self._participants] + termination_condition = self._termination_condition.dump_component() if self._termination_condition else None + return MagenticOneGroupChatConfig( + name=self.name, + description=self.description, + participants=participants, + model_client=self._model_client.dump_component(), + termination_condition=termination_condition, + max_turns=self._max_turns, + max_stalls=self._max_stalls, + final_answer_prompt=self._final_answer_prompt, + emit_team_events=self._emit_team_events, + ) + + @classmethod + def _from_config(cls, config: MagenticOneGroupChatConfig) -> Self: + participants = [ChatAgent.load_component(participant) for participant in config.participants] + model_client = ChatCompletionClient.load_component(config.model_client) + termination_condition = ( + TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None + ) + return cls( + participants=participants, + name=config.name, + description=config.description, + model_client=model_client, + termination_condition=termination_condition, + max_turns=config.max_turns, + max_stalls=config.max_stalls, + final_answer_prompt=config.final_answer_prompt, + emit_team_events=config.emit_team_events, + ) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py new file mode 100644 index 0000000..80c039c --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -0,0 +1,536 @@ +import asyncio +import json +import logging +import re +from typing import Any, Dict, List, Mapping, Sequence + +from agentdhal_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + LLMMessage, + UserMessage, +) +from agentdhal_core.utils import extract_json_from_str + +from .... import TRACE_LOGGER_NAME +from ....base import Response, TerminationCondition +from ....messages import ( + BaseAgentEvent, + BaseChatMessage, + HandoffMessage, + MessageFactory, + MultiModalMessage, + SelectSpeakerEvent, + StopMessage, + TextMessage, + ToolCallExecutionEvent, + ToolCallRequestEvent, + ToolCallSummaryMessage, +) +from ....state import MagenticOneOrchestratorState +from ....utils import remove_images +from .._base_group_chat_manager import BaseGroupChatManager +from .._events import ( + GroupChatAgentResponse, + GroupChatMessage, + GroupChatRequestPublish, + GroupChatReset, + GroupChatStart, + GroupChatTeamResponse, + GroupChatTermination, + SerializableException, +) +from ._prompts import ( + ORCHESTRATOR_FINAL_ANSWER_PROMPT, + ORCHESTRATOR_PROGRESS_LEDGER_PROMPT, + ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT, + ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT, + ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT, + ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT, + ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT, + LedgerEntry, +) + +trace_logger = logging.getLogger(TRACE_LOGGER_NAME) + + +class MagenticOneOrchestrator(BaseGroupChatManager): + """The MagenticOneOrchestrator manages a group chat with ledger based orchestration.""" + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + max_turns: int | None, + message_factory: MessageFactory, + model_client: ChatCompletionClient, + max_stalls: int, + final_answer_prompt: str, + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + emit_team_events: bool, + ): + super().__init__( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + emit_team_events=emit_team_events, + ) + self._model_client = model_client + self._max_stalls = max_stalls + self._final_answer_prompt = final_answer_prompt + self._max_json_retries = 10 + self._task = "" + self._facts = "" + self._plan = "" + self._n_rounds = 0 + self._n_stalls = 0 + + # Produce a team description. Each agent sould appear on a single line. + self._team_description = "" + for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True): + self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + self._team_description = self._team_description.strip() + + def _get_task_ledger_facts_prompt(self, task: str) -> str: + return ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT.format(task=task) + + def _get_task_ledger_plan_prompt(self, team: str) -> str: + return ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT.format(team=team) + + def _get_task_ledger_full_prompt(self, task: str, team: str, facts: str, plan: str) -> str: + return ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT.format(task=task, team=team, facts=facts, plan=plan) + + def _get_progress_ledger_prompt(self, task: str, team: str, names: List[str]) -> str: + return ORCHESTRATOR_PROGRESS_LEDGER_PROMPT.format(task=task, team=team, names=", ".join(names)) + + def _get_task_ledger_facts_update_prompt(self, task: str, facts: str) -> str: + return ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT.format(task=task, facts=facts) + + def _get_task_ledger_plan_update_prompt(self, team: str) -> str: + return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team) + + def _get_final_answer_prompt(self, task: str) -> str: + if self._final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT: + return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task) + else: + return self._final_answer_prompt + + async def _log_message(self, log_message: str) -> None: + trace_logger.debug(log_message) + + @rpc + async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: # type: ignore + """Handle the start of a task.""" + + # Check if the conversation has already terminated. + if self._termination_condition is not None and self._termination_condition.terminated: + early_stop_message = StopMessage(content="The group chat has already terminated.", source=self._name) + # Signal termination. + await self._signal_termination(early_stop_message) + # Stop the group chat. + return + assert message is not None and message.messages is not None + + # Validate the group state given all the messages. + await self.validate_group_state(message.messages) + + # Log the message to the output topic. + await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) + # Log the message to the output queue. + for msg in message.messages: + await self._output_message_queue.put(msg) + + # Outer Loop for first time + # Create the initial task ledger + ################################# + # Combine all message contents for task + self._task = " ".join([msg.to_model_text() for msg in message.messages]) + planning_conversation: List[LLMMessage] = [] + + # 1. GATHER FACTS + # create a closed book task and generate a response and update the chat history + planning_conversation.append( + UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name) + ) + response = await self._model_client.create( + self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token + ) + + assert isinstance(response.content, str) + self._facts = response.content + planning_conversation.append(AssistantMessage(content=self._facts, source=self._name)) + + # 2. CREATE A PLAN + ## plan based on available information + planning_conversation.append( + UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name) + ) + response = await self._model_client.create( + self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token + ) + + assert isinstance(response.content, str) + self._plan = response.content + + # Kick things off + self._n_stalls = 0 + await self._reenter_outer_loop(ctx.cancellation_token) + + @event + async def handle_agent_response( # type: ignore + self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext + ) -> None: # type: ignore + try: + if not isinstance(message, GroupChatAgentResponse): + raise RuntimeError("MagenticOneOrchestrator does not support GroupChatTeamResponse messages.") + delta: List[BaseAgentEvent | BaseChatMessage] = [] + if message.response.inner_messages is not None: + for inner_message in message.response.inner_messages: + delta.append(inner_message) + await self.update_message_thread([message.response.chat_message]) + delta.append(message.response.chat_message) + + if self._termination_condition is not None: + stop_message = await self._termination_condition(delta) + if stop_message is not None: + # Reset the termination conditions. + await self._termination_condition.reset() + # Signal termination. + await self._signal_termination(stop_message) + return + + await self._orchestrate_step(ctx.cancellation_token) + except Exception as e: + error = SerializableException.from_exception(e) + await self._signal_termination_with_error(error) + # Raise the error to the runtime. + raise + + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + pass + + async def save_state(self) -> Mapping[str, Any]: + state = MagenticOneOrchestratorState( + message_thread=[msg.dump() for msg in self._message_thread], + current_turn=self._current_turn, + task=self._task, + facts=self._facts, + plan=self._plan, + n_rounds=self._n_rounds, + n_stalls=self._n_stalls, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + orchestrator_state = MagenticOneOrchestratorState.model_validate(state) + self._message_thread = [self._message_factory.create(message) for message in orchestrator_state.message_thread] + self._current_turn = orchestrator_state.current_turn + self._task = orchestrator_state.task + self._facts = orchestrator_state.facts + self._plan = orchestrator_state.plan + self._n_rounds = orchestrator_state.n_rounds + self._n_stalls = orchestrator_state.n_stalls + + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Not used in this orchestrator, we select next speaker in _orchestrate_step.""" + return [""] + + async def reset(self) -> None: + """Reset the group chat manager.""" + self._message_thread.clear() + if self._termination_condition is not None: + await self._termination_condition.reset() + self._n_rounds = 0 + self._n_stalls = 0 + self._task = "" + self._facts = "" + self._plan = "" + + async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None: + """Re-enter Outer loop of the orchestrator after creating task ledger.""" + # Reset the agents + for participant_topic_type in self._participant_name_to_topic_type.values(): + await self._runtime.send_message( + GroupChatReset(), + recipient=AgentId(type=participant_topic_type, key=self.id.key), + cancellation_token=cancellation_token, + ) + # Reset partially the group chat manager + self._message_thread.clear() + + # Prepare the ledger + ledger_message = TextMessage( + content=self._get_task_ledger_full_prompt(self._task, self._team_description, self._facts, self._plan), + source=self._name, + ) + + # Save my copy + await self.update_message_thread([ledger_message]) + + # Log it to the output topic. + await self.publish_message( + GroupChatMessage(message=ledger_message), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Log it to the output queue. + await self._output_message_queue.put(ledger_message) + + # Broadcast + await self.publish_message( + GroupChatAgentResponse(response=Response(chat_message=ledger_message), name=self._name), + topic_id=DefaultTopicId(type=self._group_topic_type), + ) + + # Restart the inner loop + await self._orchestrate_step(cancellation_token=cancellation_token) + + async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None: + """Implements the inner loop of the orchestrator and selects next speaker.""" + # Check if we reached the maximum number of rounds + if self._max_turns is not None and self._n_rounds > self._max_turns: + await self._prepare_final_answer("Max rounds reached.", cancellation_token) + return + self._n_rounds += 1 + + # Update the progress ledger + context = self._thread_to_context() + + progress_ledger_prompt = self._get_progress_ledger_prompt( + self._task, self._team_description, self._participant_names + ) + context.append(UserMessage(content=progress_ledger_prompt, source=self._name)) + progress_ledger: Dict[str, Any] = {} + assert self._max_json_retries > 0 + key_error: bool = False + for _ in range(self._max_json_retries): + if self._model_client.model_info.get("structured_output", False): + response = await self._model_client.create( + self._get_compatible_context(context), json_output=LedgerEntry + ) + elif self._model_client.model_info.get("json_output", False): + response = await self._model_client.create( + self._get_compatible_context(context), cancellation_token=cancellation_token, json_output=True + ) + else: + response = await self._model_client.create( + self._get_compatible_context(context), cancellation_token=cancellation_token + ) + ledger_str = response.content + try: + assert isinstance(ledger_str, str) + output_json = extract_json_from_str(ledger_str) + if len(output_json) != 1: + raise ValueError( + f"Progress ledger should contain a single JSON object, but found: {len(progress_ledger)}" + ) + progress_ledger = output_json[0] + + # If the team consists of a single agent, deterministically set the next speaker + if len(self._participant_names) == 1: + progress_ledger["next_speaker"] = { + "reason": "The team consists of only one agent.", + "answer": self._participant_names[0], + } + + # Validate the structure + required_keys = [ + "is_request_satisfied", + "is_progress_being_made", + "is_in_loop", + "instruction_or_question", + "next_speaker", + ] + + key_error = False + for key in required_keys: + if ( + key not in progress_ledger + or not isinstance(progress_ledger[key], dict) + or "answer" not in progress_ledger[key] + or "reason" not in progress_ledger[key] + ): + key_error = True + break + + # Validate the next speaker if the task is not yet complete + if ( + not progress_ledger["is_request_satisfied"]["answer"] + and progress_ledger["next_speaker"]["answer"] not in self._participant_names + ): + key_error = True + break + + if not key_error: + break + await self._log_message(f"Failed to parse ledger information, retrying: {ledger_str}") + except (json.JSONDecodeError, TypeError): + key_error = True + await self._log_message("Invalid ledger format encountered, retrying...") + continue + if key_error: + raise ValueError("Failed to parse ledger information after multiple retries.") + await self._log_message(f"Progress Ledger: {progress_ledger}") + + # Check for task completion + if progress_ledger["is_request_satisfied"]["answer"]: + await self._log_message("Task completed, preparing final answer...") + await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token) + return + + # Check for stalling + if not progress_ledger["is_progress_being_made"]["answer"]: + self._n_stalls += 1 + elif progress_ledger["is_in_loop"]["answer"]: + self._n_stalls += 1 + else: + self._n_stalls = max(0, self._n_stalls - 1) + + # Too much stalling + if self._n_stalls >= self._max_stalls: + await self._log_message("Stall count exceeded, re-planning with the outer loop...") + await self._update_task_ledger(cancellation_token) + await self._reenter_outer_loop(cancellation_token) + return + + # Broadcast the next step + message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name) + await self.update_message_thread([message]) # My copy + + await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}") + # Log it to the output topic. + await self.publish_message( + GroupChatMessage(message=message), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Log it to the output queue. + await self._output_message_queue.put(message) + + # Broadcast it + await self.publish_message( # Broadcast + GroupChatAgentResponse(response=Response(chat_message=message), name=self._name), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=cancellation_token, + ) + + # Request that the step be completed + next_speaker = progress_ledger["next_speaker"]["answer"] + # Check if the next speaker is valid + if next_speaker not in self._participant_name_to_topic_type: + raise ValueError( + f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_names}" + ) + participant_topic_type = self._participant_name_to_topic_type[next_speaker] + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=participant_topic_type), + cancellation_token=cancellation_token, + ) + + # Send the message to the next speaker + if self._emit_team_events: + select_msg = SelectSpeakerEvent(content=[next_speaker], source=self._name) + await self.publish_message( + GroupChatMessage(message=select_msg), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + await self._output_message_queue.put(select_msg) + + async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None: + """Update the task ledger (outer loop) with the latest facts and plan.""" + context = self._thread_to_context() + + # Update the facts + update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts) + context.append(UserMessage(content=update_facts_prompt, source=self._name)) + + response = await self._model_client.create( + self._get_compatible_context(context), cancellation_token=cancellation_token + ) + + assert isinstance(response.content, str) + self._facts = response.content + context.append(AssistantMessage(content=self._facts, source=self._name)) + + # Update the plan + update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description) + context.append(UserMessage(content=update_plan_prompt, source=self._name)) + + response = await self._model_client.create( + self._get_compatible_context(context), cancellation_token=cancellation_token + ) + + assert isinstance(response.content, str) + self._plan = response.content + + async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None: + """Prepare the final answer for the task.""" + context = self._thread_to_context() + + # Get the final answer + final_answer_prompt = self._get_final_answer_prompt(self._task) + context.append(UserMessage(content=final_answer_prompt, source=self._name)) + + response = await self._model_client.create( + self._get_compatible_context(context), cancellation_token=cancellation_token + ) + assert isinstance(response.content, str) + message = TextMessage(content=response.content, source=self._name) + + await self.update_message_thread([message]) # My copy + + # Log it to the output topic. + await self.publish_message( + GroupChatMessage(message=message), + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Log it to the output queue. + await self._output_message_queue.put(message) + + # Broadcast + await self.publish_message( + GroupChatAgentResponse(response=Response(chat_message=message), name=self._name), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=cancellation_token, + ) + + if self._termination_condition is not None: + await self._termination_condition.reset() + # Signal termination + await self._signal_termination(StopMessage(content=reason, source=self._name)) + + def _thread_to_context(self) -> List[LLMMessage]: + """Convert the message thread to a context for the model.""" + context: List[LLMMessage] = [] + for m in self._message_thread: + if isinstance(m, ToolCallRequestEvent | ToolCallExecutionEvent): + # Ignore tool call messages. + continue + elif isinstance(m, StopMessage | HandoffMessage): + context.append(UserMessage(content=m.content, source=m.source)) + elif m.source == self._name: + assert isinstance(m, TextMessage | ToolCallSummaryMessage) + context.append(AssistantMessage(content=m.content, source=m.source)) + else: + assert isinstance(m, (TextMessage, MultiModalMessage, ToolCallSummaryMessage)) + context.append(UserMessage(content=m.content, source=m.source)) + return context + + def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Ensure that the messages are compatible with the underlying client, by removing images if needed.""" + if self._model_client.model_info["vision"]: + return messages + else: + return remove_images(messages) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_prompts.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_prompts.py new file mode 100644 index 0000000..846d069 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_magentic_one/_prompts.py @@ -0,0 +1,149 @@ +from pydantic import BaseModel + +ORCHESTRATOR_SYSTEM_MESSAGE = "" + + +ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT = """Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from. + +Here is the request: + +{task} + +Here is the pre-survey: + + 1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none. + 2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself. + 3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation) + 4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc. + +When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings: + + 1. GIVEN OR VERIFIED FACTS + 2. FACTS TO LOOK UP + 3. FACTS TO DERIVE + 4. EDUCATED GUESSES + +DO NOT include any other headings or sections in your response. DO NOT list next steps or plans until asked to do so. +""" + + +ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT = """Fantastic. To address this request we have assembled the following team: + +{team} + +Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""" + + +ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT = """ +We are working to address the following user request: + +{task} + + +To answer this request we have assembled the following team: + +{team} + + +Here is an initial fact sheet to consider: + +{facts} + + +Here is the plan to follow as best as possible: + +{plan} +""" + + +ORCHESTRATOR_PROGRESS_LEDGER_PROMPT = """ +Recall we are working on the following request: + +{task} + +And we have assembled the following team: + +{team} + +To make progress on the request, please answer the following questions, including necessary reasoning: + + - Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY and FULLY addressed) + - Are we in a loop where we are repeating the same requests and / or getting the same responses as before? Loops can span multiple turns, and can include repeated actions like scrolling up or down more than a handful of times. + - Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a loop or if there is evidence of significant barriers to success such as the inability to read from a required file) + - Who should speak next? (select from: {names}) + - What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need) + +Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA: + + {{ + "is_request_satisfied": {{ + "reason": string, + "answer": boolean + }}, + "is_in_loop": {{ + "reason": string, + "answer": boolean + }}, + "is_progress_being_made": {{ + "reason": string, + "answer": boolean + }}, + "next_speaker": {{ + "reason": string, + "answer": string (select from: {names}) + }}, + "instruction_or_question": {{ + "reason": string, + "answer": string + }} + }} +""" + + +class LedgerEntryBooleanAnswer(BaseModel): + reason: str + answer: bool + + +class LedgerEntryStringAnswer(BaseModel): + reason: str + answer: str + + +class LedgerEntry(BaseModel): + is_request_satisfied: LedgerEntryBooleanAnswer + is_in_loop: LedgerEntryBooleanAnswer + is_progress_being_made: LedgerEntryBooleanAnswer + next_speaker: LedgerEntryStringAnswer + instruction_or_question: LedgerEntryStringAnswer + + +ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT = """As a reminder, we are working to solve the following task: + +{task} + +It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned that may be helpful. Example edits can include (but are not limited to) adding new guesses, moving educated guesses to verified facts if appropriate, etc. Updates may be made to any section of the fact sheet, and more than one section of the fact sheet can be edited. This is an especially good time to update educated guesses, so please at least add or update one educated guess or hunch, and explain your reasoning. + +Here is the old fact sheet: + +{facts} +""" + + +ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else): + +{team} +""" + + +ORCHESTRATOR_FINAL_ANSWER_PROMPT = """ +We are working on the following task: +{task} + +We have completed the task. + +The above messages contain the conversation that took place to complete the task. + +Based on the information gathered, provide the final answer to the original request. +The answer should be phrased as if you were speaking to the user. +""" diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_round_robin_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_round_robin_group_chat.py new file mode 100644 index 0000000..5d447ba --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -0,0 +1,328 @@ +import asyncio +from typing import Any, Callable, List, Mapping, Sequence + +from agentdhal_core import AgentRuntime, Component, ComponentModel +from pydantic import BaseModel +from typing_extensions import Self + +from ...base import ChatAgent, Team, TerminationCondition +from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory +from ...state import RoundRobinManagerState +from ._base_group_chat import BaseGroupChat +from ._base_group_chat_manager import BaseGroupChatManager +from ._events import GroupChatTermination + + +class RoundRobinGroupChatManager(BaseGroupChatManager): + """A group chat manager that selects the next speaker in a round-robin fashion.""" + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + emit_team_events: bool, + ) -> None: + super().__init__( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + emit_team_events, + ) + self._next_speaker_index = 0 + + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + pass + + async def reset(self) -> None: + self._current_turn = 0 + self._message_thread.clear() + if self._termination_condition is not None: + await self._termination_condition.reset() + self._next_speaker_index = 0 + + async def save_state(self) -> Mapping[str, Any]: + state = RoundRobinManagerState( + message_thread=[message.dump() for message in self._message_thread], + current_turn=self._current_turn, + next_speaker_index=self._next_speaker_index, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + round_robin_state = RoundRobinManagerState.model_validate(state) + self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread] + self._current_turn = round_robin_state.current_turn + self._next_speaker_index = round_robin_state.next_speaker_index + + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Select a speaker from the participants in a round-robin fashion. + + .. note:: + + This method always returns a single speaker. + """ + current_speaker_index = self._next_speaker_index + self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names) + current_speaker = self._participant_names[current_speaker_index] + return current_speaker + + +class RoundRobinGroupChatConfig(BaseModel): + """The declarative configuration RoundRobinGroupChat.""" + + name: str | None = None + description: str | None = None + participants: List[ComponentModel] + termination_condition: ComponentModel | None = None + max_turns: int | None = None + emit_team_events: bool = False + + +class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]): + """A team that runs a group chat with participants taking turns in a round-robin fashion + to publish a message to all. + + If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's + :attr:`~agentdhal_agentchat.base.Response.chat_message` will be published + to other participants in the group chat. + + If a :class:`~agentdhal_agentchat.base.Team` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` + from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published + to other participants in the group chat. + + If a single participant is in the team, the participant will be the only speaker. + + Args: + participants (List[ChatAgent | Team]): The participants in the group chat. + name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_NAME` if not provided. + The name is used by a parent team to identify this group chat so it must be unique within the parent team. + description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_DESCRIPTION` if not provided. + termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None. + Without a termination condition, the group chat will run indefinitely. + max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. + custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat. + If you are using custom message types or your agents produces custom message types, you need to specify them here. + Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`. + emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + + Raises: + ValueError: If no participants are provided or if participant names are not unique. + + Examples: + + A team with one participant with tools: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + async def get_weather(location: str) -> str: + return f"The weather in {location} is sunny." + + assistant = AssistantAgent( + "Assistant", + model_client=model_client, + tools=[get_weather], + ) + termination = TextMentionTermination("TERMINATE") + team = RoundRobinGroupChat([assistant], termination_condition=termination) + await Console(team.run_stream(task="What's the weather in New York?")) + + + asyncio.run(main()) + + A team with multiple participants: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = TextMentionTermination("TERMINATE") + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + await Console(team.run_stream(task="Tell me some jokes.")) + + + asyncio.run(main()) + + A team of user proxy and a nested team of writer and reviewer agents: + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import UserProxyAgent, AssistantAgent + from agentdhal_agentchat.conditions import TextMentionTermination, MaxMessageTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + + writer = AssistantAgent( + "writer", model_client=model_client, system_message="You are a writer.", model_client_stream=True + ) + + reviewer = AssistantAgent( + "reviewer", + model_client=model_client, + system_message="Provide feedback to the input and suggest improvements.", + model_client_stream=True, + ) + + # NOTE: you can skip input by pressing Enter. + user_proxy = UserProxyAgent("user_proxy") + + # Maximum 1 round of review and revision. + inner_termination = MaxMessageTermination(max_messages=4) + + # The outter-loop termination condition that will terminate the team when the user types "exit". + outter_termination = TextMentionTermination("exit", sources=["user_proxy"]) + + team = RoundRobinGroupChat( + [ + # For each turn, the writer writes a summary and the reviewer reviews it. + RoundRobinGroupChat([writer, reviewer], termination_condition=inner_termination), + # The user proxy gets user input once the writer and reviewer have finished their actions. + user_proxy, + ], + termination_condition=outter_termination, + ) + # Start the team and wait for it to terminate. + await Console(team.run_stream(task="Write a short essay about the impact of AI on society.")) + + + asyncio.run(main()) + """ + + component_config_schema = RoundRobinGroupChatConfig + component_provider_override = "agentdhal_agentchat.teams.RoundRobinGroupChat" + + DEFAULT_NAME = "RoundRobinGroupChat" + DEFAULT_DESCRIPTION = "A team of agents." + + def __init__( + self, + participants: List[ChatAgent | Team], + *, + name: str | None = None, + description: str | None = None, + termination_condition: TerminationCondition | None = None, + max_turns: int | None = None, + runtime: AgentRuntime | None = None, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + emit_team_events: bool = False, + ) -> None: + super().__init__( + name=name or self.DEFAULT_NAME, + description=description or self.DEFAULT_DESCRIPTION, + participants=participants, + group_chat_manager_name="RoundRobinGroupChatManager", + group_chat_manager_class=RoundRobinGroupChatManager, + termination_condition=termination_condition, + max_turns=max_turns, + runtime=runtime, + custom_message_types=custom_message_types, + emit_team_events=emit_team_events, + ) + + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], RoundRobinGroupChatManager]: + def _factory() -> RoundRobinGroupChatManager: + return RoundRobinGroupChatManager( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + self._emit_team_events, + ) + + return _factory + + def _to_config(self) -> RoundRobinGroupChatConfig: + participants = [participant.dump_component() for participant in self._participants] + termination_condition = self._termination_condition.dump_component() if self._termination_condition else None + return RoundRobinGroupChatConfig( + name=self._name, + description=self._description, + participants=participants, + termination_condition=termination_condition, + max_turns=self._max_turns, + emit_team_events=self._emit_team_events, + ) + + @classmethod + def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self: + participants: List[ChatAgent | Team] = [] + for participant in config.participants: + if participant.component_type == Team.component_type: + participants.append(Team.load_component(participant)) + else: + participants.append(ChatAgent.load_component(participant)) + + termination_condition = ( + TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None + ) + return cls( + participants, + name=config.name, + description=config.description, + termination_condition=termination_condition, + max_turns=config.max_turns, + emit_team_events=config.emit_team_events, + ) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_selector_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_selector_group_chat.py new file mode 100644 index 0000000..c7ffe3c --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_selector_group_chat.py @@ -0,0 +1,730 @@ +import asyncio +import logging +import re +from inspect import iscoroutinefunction +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast + +from agentdhal_core import AgentRuntime, CancellationToken, Component, ComponentModel +from agentdhal_core.model_context import ( + ChatCompletionContext, + UnboundedChatCompletionContext, +) +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + ModelFamily, + SystemMessage, + UserMessage, +) +from pydantic import BaseModel +from typing_extensions import Self + +from ... import TRACE_LOGGER_NAME +from ...base import ChatAgent, Team, TerminationCondition +from ...messages import ( + BaseAgentEvent, + BaseChatMessage, + HandoffMessage, + MessageFactory, + ModelClientStreamingChunkEvent, + SelectorEvent, +) +from ...state import SelectorManagerState +from ._base_group_chat import BaseGroupChat +from ._base_group_chat_manager import BaseGroupChatManager +from ._events import GroupChatTermination + +trace_logger = logging.getLogger(TRACE_LOGGER_NAME) + +SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None] +AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]] +SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc] + +SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]] +AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]] +CandidateFuncType = Union[SyncCandidateFunc | AsyncCandidateFunc] + + +class SelectorGroupChatManager(BaseGroupChatManager): + """A group chat manager that selects the next speaker using a ChatCompletion + model and a custom selector function.""" + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + model_client: ChatCompletionClient, + selector_prompt: str, + allow_repeated_speaker: bool, + selector_func: Optional[SelectorFuncType], + max_selector_attempts: int, + candidate_func: Optional[CandidateFuncType], + emit_team_events: bool, + model_context: ChatCompletionContext | None, + model_client_streaming: bool = False, + ) -> None: + super().__init__( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + emit_team_events, + ) + self._model_client = model_client + self._selector_prompt = selector_prompt + self._previous_speaker: str | None = None + self._allow_repeated_speaker = allow_repeated_speaker + self._selector_func = selector_func + self._is_selector_func_async = iscoroutinefunction(self._selector_func) + self._max_selector_attempts = max_selector_attempts + self._candidate_func = candidate_func + self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) + self._model_client_streaming = model_client_streaming + if model_context is not None: + self._model_context = model_context + else: + self._model_context = UnboundedChatCompletionContext() + self._cancellation_token = CancellationToken() + + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + pass + + async def reset(self) -> None: + self._current_turn = 0 + self._message_thread.clear() + await self._model_context.clear() + if self._termination_condition is not None: + await self._termination_condition.reset() + self._previous_speaker = None + + async def save_state(self) -> Mapping[str, Any]: + state = SelectorManagerState( + message_thread=[msg.dump() for msg in self._message_thread], + current_turn=self._current_turn, + previous_speaker=self._previous_speaker, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + selector_state = SelectorManagerState.model_validate(state) + self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread] + await self._add_messages_to_context( + self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)] + ) + self._current_turn = selector_state.current_turn + self._previous_speaker = selector_state.previous_speaker + + @staticmethod + async def _add_messages_to_context( + model_context: ChatCompletionContext, + messages: Sequence[BaseChatMessage], + ) -> None: + """ + Add incoming messages to the model context. + """ + for msg in messages: + if isinstance(msg, HandoffMessage): + for llm_msg in msg.context: + await model_context.add_message(llm_msg) + await model_context.add_message(msg.to_model_message()) + + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + self._message_thread.extend(messages) + base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)] + await self._add_messages_to_context(self._model_context, base_chat_messages) + + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Selects the next speaker in a group chat using a ChatCompletion client, + with the selector function as override if it returns a speaker name. + + .. note:: + + This method always returns a single speaker name. + + A key assumption is that the agent type is the same as the topic type, which we use as the agent name. + """ + # Use the selector function if provided. + if self._selector_func is not None: + if self._is_selector_func_async: + async_selector_func = cast(AsyncSelectorFunc, self._selector_func) + speaker = await async_selector_func(thread) + else: + sync_selector_func = cast(SyncSelectorFunc, self._selector_func) + speaker = sync_selector_func(thread) + if speaker is not None: + if speaker not in self._participant_names: + raise ValueError( + f"Selector function returned an invalid speaker name: {speaker}. " + f"Expected one of: {self._participant_names}." + ) + # Skip the model based selection. + return [speaker] + + # Use the candidate function to filter participants if provided + if self._candidate_func is not None: + if self._is_candidate_func_async: + async_candidate_func = cast(AsyncCandidateFunc, self._candidate_func) + participants = await async_candidate_func(thread) + else: + sync_candidate_func = cast(SyncCandidateFunc, self._candidate_func) + participants = sync_candidate_func(thread) + if not participants: + raise ValueError("Candidate function must return a non-empty list of participant names.") + if not all(p in self._participant_names for p in participants): + raise ValueError( + f"Candidate function returned invalid participant names: {participants}. " + f"Expected one of: {self._participant_names}." + ) + else: + # Construct the candidate agent list to be selected from, skip the previous speaker if not allowed. + if self._previous_speaker is not None and not self._allow_repeated_speaker: + participants = [p for p in self._participant_names if p != self._previous_speaker] + else: + participants = list(self._participant_names) + + assert len(participants) > 0 + + # Construct agent roles. + # Each agent sould appear on a single line. + roles = "" + for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True): + roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + roles = roles.strip() + + # Select the next speaker. + if len(participants) > 1: + agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts) + else: + agent_name = participants[0] + self._previous_speaker = agent_name + trace_logger.debug(f"Selected speaker: {agent_name}") + return [agent_name] + + def construct_message_history(self, message_history: List[LLMMessage]) -> str: + # Construct the history of the conversation. + history_messages: List[str] = [] + for msg in message_history: + if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage): + message = f"{msg.source}: {msg.content}" + history_messages.append( + message.rstrip() + "\n\n" + ) # Create some consistency for how messages are separated in the transcript + + history: str = "\n".join(history_messages) + return history + + async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str: + model_context_messages = await self._model_context.get_messages() + model_context_history = self.construct_message_history(model_context_messages) + + select_speaker_prompt = self._selector_prompt.format( + roles=roles, participants=str(participants), history=model_context_history + ) + + select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] + if ModelFamily.is_openai(self._model_client.model_info["family"]): + select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")] + + num_attempts = 0 + while num_attempts < max_attempts: + num_attempts += 1 + if self._model_client_streaming: + chunk: CreateResult | str = "" + async for _chunk in self._model_client.create_stream(messages=select_speaker_messages): + chunk = _chunk + if self._emit_team_events: + if isinstance(chunk, str): + await self._output_message_queue.put( + ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name) + ) + else: + assert isinstance(chunk, CreateResult) + assert isinstance(chunk.content, str) + await self._output_message_queue.put( + SelectorEvent(content=chunk.content, source=self._name) + ) + # The last chunk must be CreateResult. + assert isinstance(chunk, CreateResult) + response = chunk + else: + response = await self._model_client.create(messages=select_speaker_messages) + assert isinstance(response.content, str) + select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) + # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed. + # This is because the model may still select the previous speaker, and we want to catch that. + mentions = self._mentioned_agents(response.content, self._participant_names) + if len(mentions) == 0: + trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})") + feedback = f"No valid name was mentioned. Please select from: {str(participants)}." + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + elif len(mentions) > 1: + trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})") + feedback = ( + f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}." + ) + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + else: + agent_name = list(mentions.keys())[0] + if ( + not self._allow_repeated_speaker + and self._previous_speaker is not None + and agent_name == self._previous_speaker + ): + trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})") + feedback = ( + f"Repeated speaker is not allowed, please select a different name from: {str(participants)}." + ) + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + else: + # Valid selection + trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})") + return agent_name + + if self._previous_speaker is not None: + trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.") + return self._previous_speaker + trace_logger.warning( + f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant." + ) + return participants[0] + + def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]: + """Counts the number of times each agent is mentioned in the provided message content. + Agent names will match under any of the following conditions (all case-sensitive): + - Exact name match + - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer') + - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer') + + Args: + message_content (Union[str, List]): The content of the message, either as a single string or a list of strings. + agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content. + + Returns: + Dict: a counter for mentioned agents. + """ + mentions: Dict[str, int] = dict() + for name in agent_names: + # Finds agent mentions, taking word boundaries into account, + # accommodates escaping underscores and underscores as spaces + regex = ( + r"(?<=\W)(" + + re.escape(name) + + r"|" + + re.escape(name.replace("_", " ")) + + r"|" + + re.escape(name.replace("_", r"\_")) + + r")(?=\W)" + ) + # Pad the message to help with matching + count = len(re.findall(regex, f" {message_content} ")) + if count > 0: + mentions[name] = count + return mentions + + +class SelectorGroupChatConfig(BaseModel): + """The declarative configuration for SelectorGroupChat.""" + + name: str | None = None + description: str | None = None + participants: List[ComponentModel] + model_client: ComponentModel + termination_condition: ComponentModel | None = None + max_turns: int | None = None + selector_prompt: str + allow_repeated_speaker: bool + # selector_func: ComponentModel | None + max_selector_attempts: int = 3 + emit_team_events: bool = False + model_client_streaming: bool = False + model_context: ComponentModel | None = None + + +class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): + """A group chat team that have participants takes turn to publish a message + to all, using a ChatCompletion model to select the next speaker after each message. + + If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's + :attr:`~agentdhal_agentchat.base.Response.chat_message` will be published + to other participants in the group chat. + + If a :class:`~agentdhal_agentchat.base.Team` is a participant, + the :class:`~agentdhal_agentchat.messages.BaseChatMessage` + from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published + to other participants in the group chat. + + Args: + participants (List[ChatAgent | Team]): The participants in the group chat, + must have unique names and at least two participants. + model_client (ChatCompletionClient): The ChatCompletion model client used + to select the next speaker. + name (str | None, optional): The name of the group chat, using + :attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_NAME` if not provided. + The name is used by a parent team to identify this group chat so it must + be unique within the parent team. + description (str | None, optional): The description of the group chat, using + :attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_DESCRIPTION` if not provided. + termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None. + Without a termination condition, the group chat will run indefinitely. + max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. + selector_prompt (str, optional): The prompt template to use for selecting the next speaker. + Available fields: '{roles}', '{participants}', and '{history}'. + `{participants}` is the names of candidates for selection. The format is `["", "", ...]`. + `{roles}` is a newline-separated list of names and descriptions of the candidate agents. The format for each line is: `" : "`. + `{history}` is the conversation history formatted as a double newline separated of names and message content. The format for each message is: `" : "`. + allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn. + Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens. + max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3. + If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available, + otherwise the first participant will be used. + selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector + function that takes the conversation history and returns the name of the next speaker. + If provided, this function will be used to override the model to select the next speaker. + If the function returns None, the model will be used to select the next speaker. + NOTE: `selector_func` is not serializable and will be ignored during serialization and deserialization process. + candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional): + A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker + selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. + This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. + custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat. + If you are using custom message types or your agents produces custom message types, you need to specify them here. + Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`. + emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False. + model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving + :class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset. + + Raises: + ValueError: If the number of participants is less than two or if the selector prompt is invalid. + + Examples: + + A team with multiple participants: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import SelectorGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + + async def book_trip() -> str: + return "Your trip is booked!" + + travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + termination = TextMentionTermination("TERMINATE") + team = SelectorGroupChat( + [travel_advisor, hotel_agent, flight_agent], + model_client=model_client, + termination_condition=termination, + ) + await Console(team.run_stream(task="Book a 3-day trip to new york.")) + + + asyncio.run(main()) + + A team with a custom selector function: + + .. code-block:: python + + import asyncio + from typing import Sequence + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import SelectorGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + def check_calculation(x: int, y: int, answer: int) -> str: + if x + y == answer: + return "Correct!" + else: + return "Incorrect!" + + agent1 = AssistantAgent( + "Agent1", + model_client, + description="For calculation", + system_message="Calculate the sum of two numbers", + ) + agent2 = AssistantAgent( + "Agent2", + model_client, + tools=[check_calculation], + description="For checking calculation", + system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'", + ) + + def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None: + if len(messages) == 1 or messages[-1].to_text() == "Incorrect!": + return "Agent1" + if messages[-1].source == "Agent1": + return "Agent2" + return None + + termination = TextMentionTermination("Correct!") + team = SelectorGroupChat( + [agent1, agent2], + model_client=model_client, + selector_func=selector_func, + termination_condition=termination, + ) + + await Console(team.run_stream(task="What is 1 + 1?")) + + + asyncio.run(main()) + + A team with custom model context: + + .. code-block:: python + + import asyncio + + from agentdhal_core.model_context import BufferedChatCompletionContext + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.teams import SelectorGroupChat + from agentdhal_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + model_context = BufferedChatCompletionContext(buffer_size=5) + + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + + async def book_trip() -> str: + return "Your trip is booked!" + + travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + termination = TextMentionTermination("TERMINATE") + team = SelectorGroupChat( + [travel_advisor, hotel_agent, flight_agent], + model_client=model_client, + termination_condition=termination, + model_context=model_context, + ) + await Console(team.run_stream(task="Book a 3-day trip to new york.")) + + + asyncio.run(main()) + """ + + component_config_schema = SelectorGroupChatConfig + component_provider_override = "agentdhal_agentchat.teams.SelectorGroupChat" + + DEFAULT_NAME = "SelectorGroupChat" + DEFAULT_DESCRIPTION = "A team of agents." + + def __init__( + self, + participants: List[ChatAgent | Team], + model_client: ChatCompletionClient, + *, + name: str | None = None, + description: str | None = None, + termination_condition: TerminationCondition | None = None, + max_turns: int | None = None, + runtime: AgentRuntime | None = None, + selector_prompt: str = """You are in a role play game. The following roles are available: +{roles}. +Read the following conversation. Then select the next role from {participants} to play. Only return the role. + +{history} + +Read the above conversation. Then select the next role from {participants} to play. Only return the role. +""", + allow_repeated_speaker: bool = False, + max_selector_attempts: int = 3, + selector_func: Optional[SelectorFuncType] = None, + candidate_func: Optional[CandidateFuncType] = None, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + emit_team_events: bool = False, + model_client_streaming: bool = False, + model_context: ChatCompletionContext | None = None, + ): + super().__init__( + name=name or self.DEFAULT_NAME, + description=description or self.DEFAULT_DESCRIPTION, + participants=participants, + group_chat_manager_name="SelectorGroupChatManager", + group_chat_manager_class=SelectorGroupChatManager, + termination_condition=termination_condition, + max_turns=max_turns, + runtime=runtime, + custom_message_types=custom_message_types, + emit_team_events=emit_team_events, + ) + # Validate the participants. + if len(participants) < 2: + raise ValueError("At least two participants are required for SelectorGroupChat.") + self._selector_prompt = selector_prompt + self._model_client = model_client + self._allow_repeated_speaker = allow_repeated_speaker + self._selector_func = selector_func + self._max_selector_attempts = max_selector_attempts + self._candidate_func = candidate_func + self._model_client_streaming = model_client_streaming + self._model_context = model_context + + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], BaseGroupChatManager]: + return lambda: SelectorGroupChatManager( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + self._model_client, + self._selector_prompt, + self._allow_repeated_speaker, + self._selector_func, + self._max_selector_attempts, + self._candidate_func, + self._emit_team_events, + self._model_context, + self._model_client_streaming, + ) + + def _to_config(self) -> SelectorGroupChatConfig: + return SelectorGroupChatConfig( + name=self._name, + description=self._description, + participants=[participant.dump_component() for participant in self._participants], + model_client=self._model_client.dump_component(), + termination_condition=self._termination_condition.dump_component() if self._termination_condition else None, + max_turns=self._max_turns, + selector_prompt=self._selector_prompt, + allow_repeated_speaker=self._allow_repeated_speaker, + max_selector_attempts=self._max_selector_attempts, + # selector_func=self._selector_func.dump_component() if self._selector_func else None, + emit_team_events=self._emit_team_events, + model_client_streaming=self._model_client_streaming, + model_context=self._model_context.dump_component() if self._model_context else None, + ) + + @classmethod + def _from_config(cls, config: SelectorGroupChatConfig) -> Self: + participants: List[ChatAgent | Team] = [] + for participant in config.participants: + if participant.component_type == ChatAgent.component_type: + participants.append(ChatAgent.load_component(participant)) + elif participant.component_type == Team.component_type: + participants.append(Team.load_component(participant)) + else: + raise ValueError( + f"Invalid participant component type: {participant.component_type}. " "Expected ChatAgent or Team." + ) + return cls( + participants=participants, + model_client=ChatCompletionClient.load_component(config.model_client), + name=config.name, + description=config.description, + termination_condition=TerminationCondition.load_component(config.termination_condition) + if config.termination_condition + else None, + max_turns=config.max_turns, + selector_prompt=config.selector_prompt, + allow_repeated_speaker=config.allow_repeated_speaker, + max_selector_attempts=config.max_selector_attempts, + # selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]) + # if config.selector_func + # else None, + emit_team_events=config.emit_team_events, + model_client_streaming=config.model_client_streaming, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, + ) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_sequential_routed_agent.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_sequential_routed_agent.py new file mode 100644 index 0000000..fd66703 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_sequential_routed_agent.py @@ -0,0 +1,72 @@ +import asyncio +from typing import Any, Sequence + +from agentdhal_core import MessageContext, RoutedAgent + + +class FIFOLock: + """A lock that ensures coroutines acquire the lock in the order they request it.""" + + def __init__(self) -> None: + self._queue = asyncio.Queue[asyncio.Event]() + self._locked = False + + async def acquire(self) -> None: + # If the lock is not held by any coroutine, set the lock to be held + # by the current coroutine. + if not self._locked: + self._locked = True + return + + # If the lock is held by another coroutine, create an event and put it + # in the queue. Wait for the event to be set. + event = asyncio.Event() + await self._queue.put(event) + await event.wait() + + def release(self) -> None: + if not self._queue.empty(): + # If there are events in the queue, get the next event and set it. + next_event = self._queue.get_nowait() + next_event.set() + else: + # If there are no events in the queue, release the lock. + self._locked = False + + +class SequentialRoutedAgent(RoutedAgent): + """A subclass of :class:`agentdhal_core.RoutedAgent` that ensures + that messages of certain types are processed sequentially + using a FIFO lock. + + This is useful for agents that need to maintain a strict order of + processing messages, such as in a group chat scenario. + + + + Args: + + description (str): The description of the agent. + sequential_message_types (Sequence[Type[Any]]): A sequence of message types that should be + processed sequentially. If a message of one of these types is received, + the agent will acquire a FIFO lock to ensure that it is processed + before any later messages that are also one of these types. + """ + + def __init__(self, description: str, sequential_message_types: Sequence[type[Any]]) -> None: + super().__init__(description=description) + self._fifo_lock = FIFOLock() + self._sequential_message_types = sequential_message_types + + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: + if any(isinstance(message, sequential_type) for sequential_type in self._sequential_message_types): + # Acquire the FIFO lock to ensure that this message is processed + # in the order it was received. + await self._fifo_lock.acquire() + try: + return await super().on_message_impl(message, ctx) + finally: + # Release the FIFO lock to allow the next message to be processed. + self._fifo_lock.release() + # If the message is not of a sequential type, process it normally. + return await super().on_message_impl(message, ctx) diff --git a/agent_dhal/agentdhal_agentchat/teams/_group_chat/_swarm_group_chat.py b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_swarm_group_chat.py new file mode 100644 index 0000000..fe8eef2 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -0,0 +1,321 @@ +import asyncio +from typing import Any, Callable, List, Mapping, Sequence + +from agentdhal_core import AgentRuntime, Component, ComponentModel +from pydantic import BaseModel + +from ...base import ChatAgent, TerminationCondition +from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory +from ...state import SwarmManagerState +from ._base_group_chat import BaseGroupChat +from ._base_group_chat_manager import BaseGroupChatManager +from ._events import GroupChatTermination + + +class SwarmGroupChatManager(BaseGroupChatManager): + """A group chat manager that selects the next speaker based on handoff message only.""" + + def __init__( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + emit_team_events: bool, + ) -> None: + super().__init__( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + emit_team_events, + ) + self._current_speaker = self._participant_names[0] + + async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: + """Validate the start messages for the group chat.""" + # Check if any of the start messages is a handoff message. + if messages: + for message in messages: + if isinstance(message, HandoffMessage): + if message.target not in self._participant_names: + raise ValueError( + f"The target {message.target} is not one of the participants {self._participant_names}. " + "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." + ) + return + + # Check if there is a handoff message in the thread that is not targeting a valid participant. + for existing_message in reversed(self._message_thread): + if isinstance(existing_message, HandoffMessage): + if existing_message.target not in self._participant_names: + raise ValueError( + f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. " + "If you are resuming Swarm with a new task make sure to include in your task " + "a HandoffMessage with a valid participant as the target. For example, if you are " + "resuming from a HandoffTermination, make sure the new task is a HandoffMessage " + "with a valid participant as the target." + ) + # The latest handoff message should always target a valid participant. + # Do not look past the latest handoff message. + return + + async def reset(self) -> None: + self._current_turn = 0 + self._message_thread.clear() + if self._termination_condition is not None: + await self._termination_condition.reset() + self._current_speaker = self._participant_names[0] + + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Select a speaker from the participants based on handoff message. + Looks for the last handoff message in the thread to determine the next speaker. + + .. note:: + + This method always returns a single speaker. + """ + if len(thread) == 0: + return [self._current_speaker] + for message in reversed(thread): + if isinstance(message, HandoffMessage): + self._current_speaker = message.target + # The latest handoff message should always target a valid participant. + assert self._current_speaker in self._participant_names + return [self._current_speaker] + return self._current_speaker + + async def save_state(self) -> Mapping[str, Any]: + state = SwarmManagerState( + message_thread=[msg.dump() for msg in self._message_thread], + current_turn=self._current_turn, + current_speaker=self._current_speaker, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + swarm_state = SwarmManagerState.model_validate(state) + self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread] + self._current_turn = swarm_state.current_turn + self._current_speaker = swarm_state.current_speaker + + +class SwarmConfig(BaseModel): + """The declarative configuration for Swarm.""" + + name: str | None = None + description: str | None = None + participants: List[ComponentModel] + termination_condition: ComponentModel | None = None + max_turns: int | None = None + emit_team_events: bool = False + + +class Swarm(BaseGroupChat, Component[SwarmConfig]): + """A group chat team that selects the next speaker based on handoff message only. + + The first participant in the list of participants is the initial speaker. + The next speaker is selected based on the :class:`~agentdhal_agentchat.messages.HandoffMessage` message + sent by the current speaker. If no handoff message is sent, the current speaker + continues to be the speaker. + + .. note:: + + Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and + :class:`~agentdhal_agentchat.teams.SelectorGroupChat`, this group chat + team does not support inner teams as participants. + + Args: + participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker. + name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_NAME` if not provided. + The name is used by a parent team to identify this group chat so it must be unique within the parent team. + description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_DESCRIPTION` if not provided. + termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None. + Without a termination condition, the group chat will run indefinitely. + max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. + custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat. + If you are using custom message types or your agents produces custom message types, you need to specify them here. + Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`. + emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + + Basic example: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import Swarm + from agentdhal_agentchat.conditions import MaxMessageTermination + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent( + "Alice", + model_client=model_client, + handoffs=["Bob"], + system_message="You are Alice and you only answer questions about yourself.", + ) + agent2 = AssistantAgent( + "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January." + ) + + termination = MaxMessageTermination(3) + team = Swarm([agent1, agent2], termination_condition=termination) + + stream = team.run_stream(task="What is bob's birthday?") + async for message in stream: + print(message) + + + asyncio.run(main()) + + + Using the :class:`~agentdhal_agentchat.conditions.HandoffTermination` for human-in-the-loop handoff: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import Swarm + from agentdhal_agentchat.conditions import HandoffTermination, MaxMessageTermination + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.messages import HandoffMessage + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent = AssistantAgent( + "Alice", + model_client=model_client, + handoffs=["user"], + system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.", + ) + termination = HandoffTermination(target="user") | MaxMessageTermination(3) + team = Swarm([agent], termination_condition=termination) + + # Start the conversation. + await Console(team.run_stream(task="What is bob's birthday?")) + + # Resume with user feedback. + await Console( + team.run_stream( + task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.") + ) + ) + + + asyncio.run(main()) + """ + + component_config_schema = SwarmConfig + component_provider_override = "agentdhal_agentchat.teams.Swarm" + + DEFAULT_NAME = "Swarm" + DEFAULT_DESCRIPTION = "A team of agents." + + def __init__( + self, + participants: List[ChatAgent], + *, + name: str | None = None, + description: str | None = None, + termination_condition: TerminationCondition | None = None, + max_turns: int | None = None, + runtime: AgentRuntime | None = None, + custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + emit_team_events: bool = False, + ) -> None: + for participant in participants: + if not isinstance(participant, ChatAgent): + raise TypeError(f"Participant {participant} must be a ChatAgent.") + super().__init__( + name=name or self.DEFAULT_NAME, + description=description or self.DEFAULT_DESCRIPTION, + participants=[participant for participant in participants], + group_chat_manager_name="SwarmGroupChatManager", + group_chat_manager_class=SwarmGroupChatManager, + termination_condition=termination_condition, + max_turns=max_turns, + runtime=runtime, + custom_message_types=custom_message_types, + emit_team_events=emit_team_events, + ) + # The first participant must be able to produce handoff messages. + first_participant = self._participants[0] + assert isinstance(first_participant, ChatAgent) + if HandoffMessage not in first_participant.produced_message_types: + raise ValueError("The first participant must be able to produce a handoff messages.") + + def _create_group_chat_manager_factory( + self, + name: str, + group_topic_type: str, + output_topic_type: str, + participant_topic_types: List[str], + participant_names: List[str], + participant_descriptions: List[str], + output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination], + termination_condition: TerminationCondition | None, + max_turns: int | None, + message_factory: MessageFactory, + ) -> Callable[[], SwarmGroupChatManager]: + def _factory() -> SwarmGroupChatManager: + return SwarmGroupChatManager( + name, + group_topic_type, + output_topic_type, + participant_topic_types, + participant_names, + participant_descriptions, + output_message_queue, + termination_condition, + max_turns, + message_factory, + self._emit_team_events, + ) + + return _factory + + def _to_config(self) -> SwarmConfig: + participants = [participant.dump_component() for participant in self._participants] + termination_condition = self._termination_condition.dump_component() if self._termination_condition else None + return SwarmConfig( + name=self._name, + description=self._description, + participants=participants, + termination_condition=termination_condition, + max_turns=self._max_turns, + emit_team_events=self._emit_team_events, + ) + + @classmethod + def _from_config(cls, config: SwarmConfig) -> "Swarm": + participants = [ChatAgent.load_component(participant) for participant in config.participants] + termination_condition = ( + TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None + ) + return cls( + participants, + name=config.name, + description=config.description, + termination_condition=termination_condition, + max_turns=config.max_turns, + emit_team_events=config.emit_team_events, + ) diff --git a/agent_dhal/agentdhal_agentchat/tools/__init__.py b/agent_dhal/agentdhal_agentchat/tools/__init__.py new file mode 100644 index 0000000..9884ddc --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/tools/__init__.py @@ -0,0 +1,4 @@ +from ._agent import AgentTool +from ._team import TeamTool + +__all__ = ["AgentTool", "TeamTool"] diff --git a/agent_dhal/agentdhal_agentchat/tools/_agent.py b/agent_dhal/agentdhal_agentchat/tools/_agent.py new file mode 100644 index 0000000..2b7a257 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/tools/_agent.py @@ -0,0 +1,93 @@ +from agentdhal_core import Component, ComponentModel +from pydantic import BaseModel +from typing_extensions import Self + +from agentdhal_agentchat.agents import BaseChatAgent + +from ._task_runner_tool import TaskRunnerTool + + +class AgentToolConfig(BaseModel): + """Configuration for the AgentTool.""" + + agent: ComponentModel + """The agent to be used for running the task.""" + + return_value_as_last_message: bool = False + """Whether to return the value as the last message of the task result.""" + + +class AgentTool(TaskRunnerTool, Component[AgentToolConfig]): + """Tool that can be used to run a task using an agent. + + The tool returns the result of the task execution as a :class:`~agentdhal_agentchat.base.TaskResult` object. + + .. important:: + When using AgentTool, you **must** disable parallel tool calls in the model client configuration + to avoid concurrency issues. Agents cannot run concurrently as they maintain internal state + that would conflict with parallel execution. For example, set ``parallel_tool_calls=False`` + for :class:`~agentdhal_extensions.models.openai.OpenAIChatCompletionClient` and + :class:`~agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient`. + + Args: + agent (BaseChatAgent): The agent to be used for running the task. + return_value_as_last_message (bool): Whether to use the last message content of the task result + as the return value of the tool in :meth:`~agentdhal_agentchat.tools.TaskRunnerTool.return_value_as_string`. + If set to True, the last message content will be returned as a string. + If set to False, the tool will return all messages in the task result as a string concatenated together, + with each message prefixed by its source (e.g., "writer: ...", "assistant: ..."). + + Example: + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.tools import AgentTool + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4.1") + writer = AssistantAgent( + name="writer", + description="A writer agent for generating text.", + model_client=model_client, + system_message="Write well.", + ) + writer_tool = AgentTool(agent=writer) + + # Create model client with parallel tool calls disabled for the main agent + main_model_client = OpenAIChatCompletionClient(model="gpt-4.1", parallel_tool_calls=False) + assistant = AssistantAgent( + name="assistant", + model_client=main_model_client, + tools=[writer_tool], + system_message="You are a helpful assistant.", + ) + await Console(assistant.run_stream(task="Write a poem about the sea.")) + + + asyncio.run(main()) + """ + + component_config_schema = AgentToolConfig + component_provider_override = "agentdhal_agentchat.tools.AgentTool" + + def __init__(self, agent: BaseChatAgent, return_value_as_last_message: bool = False) -> None: + self._agent = agent + super().__init__( + agent, agent.name, agent.description, return_value_as_last_message=return_value_as_last_message + ) + + def _to_config(self) -> AgentToolConfig: + return AgentToolConfig( + agent=self._agent.dump_component(), + return_value_as_last_message=self._return_value_as_last_message, + ) + + @classmethod + def _from_config(cls, config: AgentToolConfig) -> Self: + return cls(BaseChatAgent.load_component(config.agent), config.return_value_as_last_message) diff --git a/agent_dhal/agentdhal_agentchat/tools/_task_runner_tool.py b/agent_dhal/agentdhal_agentchat/tools/_task_runner_tool.py new file mode 100644 index 0000000..65380b1 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/tools/_task_runner_tool.py @@ -0,0 +1,72 @@ +from abc import ABC +from typing import Annotated, Any, AsyncGenerator, List, Mapping + +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseStreamTool +from pydantic import BaseModel + +from ..agents import BaseChatAgent +from ..base import TaskResult +from ..messages import BaseAgentEvent, BaseChatMessage +from ..teams import BaseGroupChat + + +class TaskRunnerToolArgs(BaseModel): + """Input for the TaskRunnerTool.""" + + task: Annotated[str, "The task to be executed."] + + +class TaskRunnerTool(BaseStreamTool[TaskRunnerToolArgs, BaseAgentEvent | BaseChatMessage, TaskResult], ABC): + """An base class for tool that can be used to run a task using a team or an agent.""" + + component_type = "tool" + + def __init__( + self, + task_runner: BaseGroupChat | BaseChatAgent, + name: str, + description: str, + return_value_as_last_message: bool, + ) -> None: + self._task_runner = task_runner + self._return_value_as_last_message = return_value_as_last_message + super().__init__( + args_type=TaskRunnerToolArgs, + return_type=TaskResult, + name=name, + description=description, + strict=True, + ) + + async def run(self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken) -> TaskResult: + """Run the task and return the result.""" + return await self._task_runner.run(task=args.task, cancellation_token=cancellation_token) + + async def run_stream( + self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]: + """Run the task and yield events or messages as they are produced, the final :class:`TaskResult` + will be yielded at the end.""" + async for event in self._task_runner.run_stream(task=args.task, cancellation_token=cancellation_token): + yield event + + def return_value_as_string(self, value: TaskResult) -> str: + """Convert the task result to a string.""" + if self._return_value_as_last_message: + if value.messages and isinstance(value.messages[-1], BaseChatMessage): + return value.messages[-1].to_model_text() + raise ValueError("The last message is not a BaseChatMessage.") + parts: List[str] = [] + for message in value.messages: + if isinstance(message, BaseChatMessage): + if message.source == "user": + continue + parts.append(f"{message.source}: {message.to_model_text()}") + return "\n\n".join(parts) + + async def save_state_json(self) -> Mapping[str, Any]: + return await self._task_runner.save_state() + + async def load_state_json(self, state: Mapping[str, Any]) -> None: + await self._task_runner.load_state(state) diff --git a/agent_dhal/agentdhal_agentchat/tools/_team.py b/agent_dhal/agentdhal_agentchat/tools/_team.py new file mode 100644 index 0000000..31afd77 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/tools/_team.py @@ -0,0 +1,133 @@ +from agentdhal_core import Component, ComponentModel +from pydantic import BaseModel +from typing_extensions import Self + +from agentdhal_agentchat.teams import BaseGroupChat + +from ._task_runner_tool import TaskRunnerTool + + +class TeamToolConfig(BaseModel): + """Configuration for the TeamTool.""" + + name: str + """The name of the tool.""" + description: str + """The name and description of the tool.""" + team: ComponentModel + """The team to be used for running the task.""" + return_value_as_last_message: bool = False + """Whether to return the value as the last message of the task result.""" + + +class TeamTool(TaskRunnerTool, Component[TeamToolConfig]): + """Tool that can be used to run a task. + + The tool returns the result of the task execution as a :class:`~agentdhal_agentchat.base.TaskResult` object. + + .. important:: + When using TeamTool, you **must** disable parallel tool calls in the model client configuration + to avoid concurrency issues. Teams cannot run concurrently as they maintain internal state + that would conflict with parallel execution. For example, set ``parallel_tool_calls=False`` + for :class:`~agentdhal_extensions.models.openai.OpenAIChatCompletionClient` and + :class:`~agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient`. + + Args: + team (BaseGroupChat): The team to be used for running the task. + name (str): The name of the tool. + description (str): The description of the tool. + return_value_as_last_message (bool): Whether to use the last message content of the task result + as the return value of the tool in :meth:`~agentdhal_agentchat.tools.TaskRunnerTool.return_value_as_string`. + If set to True, the last message content will be returned as a string. + If set to False, the tool will return all messages in the task result as a string concatenated together, + with each message prefixed by its source (e.g., "writer: ...", "assistant: ..."). + + Example: + + .. code-block:: python + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import SourceMatchTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.tools import TeamTool + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + # Disable parallel tool calls when using TeamTool + model_client = OpenAIChatCompletionClient(model="gpt-4.1") + + writer = AssistantAgent(name="writer", model_client=model_client, system_message="You are a helpful assistant.") + reviewer = AssistantAgent( + name="reviewer", model_client=model_client, system_message="You are a critical reviewer." + ) + summarizer = AssistantAgent( + name="summarizer", + model_client=model_client, + system_message="You combine the review and produce a revised response.", + ) + team = RoundRobinGroupChat( + [writer, reviewer, summarizer], termination_condition=SourceMatchTermination(sources=["summarizer"]) + ) + + # Create a TeamTool that uses the team to run tasks, returning the last message as the result. + tool = TeamTool( + team=team, + name="writing_team", + description="A tool for writing tasks.", + return_value_as_last_message=True, + ) + + # Create model client with parallel tool calls disabled for the main agent + main_model_client = OpenAIChatCompletionClient(model="gpt-4.1", parallel_tool_calls=False) + main_agent = AssistantAgent( + name="main_agent", + model_client=main_model_client, + system_message="You are a helpful assistant that can use the writing tool.", + tools=[tool], + ) + # For handling each events manually. + # async for message in main_agent.run_stream( + # task="Write a short story about a robot learning to love.", + # ): + # print(message) + # Use Console to display the messages in a more readable format. + await Console( + main_agent.run_stream( + task="Write a short story about a robot learning to love.", + ) + ) + + + if __name__ == "__main__": + import asyncio + + asyncio.run(main()) + """ + + component_config_schema = TeamToolConfig + component_provider_override = "agentdhal_agentchat.tools.TeamTool" + + def __init__( + self, team: BaseGroupChat, name: str, description: str, return_value_as_last_message: bool = False + ) -> None: + self._team = team + super().__init__(team, name, description, return_value_as_last_message=return_value_as_last_message) + + def _to_config(self) -> TeamToolConfig: + return TeamToolConfig( + name=self._name, + description=self._description, + team=self._team.dump_component(), + return_value_as_last_message=self._return_value_as_last_message, + ) + + @classmethod + def _from_config(cls, config: TeamToolConfig) -> Self: + return cls( + BaseGroupChat.load_component(config.team), + config.name, + config.description, + config.return_value_as_last_message, + ) diff --git a/agent_dhal/agentdhal_agentchat/ui/__init__.py b/agent_dhal/agentdhal_agentchat/ui/__init__.py new file mode 100644 index 0000000..9cc0837 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/ui/__init__.py @@ -0,0 +1,7 @@ +""" +This module implements utility classes for formatting/printing agent messages. +""" + +from ._console import Console, UserInputManager + +__all__ = ["Console", "UserInputManager"] diff --git a/agent_dhal/agentdhal_agentchat/ui/_console.py b/agent_dhal/agentdhal_agentchat/ui/_console.py new file mode 100644 index 0000000..ab70f7e --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/ui/_console.py @@ -0,0 +1,204 @@ +import asyncio +import os +import sys +import time +from inspect import iscoroutinefunction +from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast + +from agentdhal_core import CancellationToken +from agentdhal_core.models import RequestUsage + +from agentdhal_agentchat.agents import UserProxyAgent +from agentdhal_agentchat.base import Response, TaskResult +from agentdhal_agentchat.messages import ( + BaseAgentEvent, + BaseChatMessage, + ModelClientStreamingChunkEvent, + MultiModalMessage, + UserInputRequestedEvent, +) + + +def _is_running_in_iterm() -> bool: + return os.getenv("TERM_PROGRAM") == "iTerm.app" + + +def _is_output_a_tty() -> bool: + return sys.stdout.isatty() + + +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + +T = TypeVar("T", bound=TaskResult | Response) + + +class UserInputManager: + def __init__(self, callback: InputFuncType): + self.input_events: Dict[str, asyncio.Event] = {} + self.callback = callback + + def get_wrapped_callback(self) -> AsyncInputFunc: + async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + # Lookup the event for the prompt, if it exists wait for it. + # If it doesn't exist, create it and store it. + # Get request ID: + request_id = UserProxyAgent.InputRequestContext.request_id() + if request_id in self.input_events: + event = self.input_events[request_id] + else: + event = asyncio.Event() + self.input_events[request_id] = event + + await event.wait() + + del self.input_events[request_id] + + if iscoroutinefunction(self.callback): + # Cast to AsyncInputFunc for proper typing + async_func = cast(AsyncInputFunc, self.callback) + return await async_func(prompt, cancellation_token) + else: + # Cast to SyncInputFunc for proper typing + sync_func = cast(SyncInputFunc, self.callback) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, sync_func, prompt) + + return user_input_func_wrapper + + def notify_event_received(self, request_id: str) -> None: + if request_id in self.input_events: + self.input_events[request_id].set() + else: + event = asyncio.Event() + self.input_events[request_id] = event + + +def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]: + return asyncio.to_thread(print, output, end=end, flush=flush) + + +async def Console( + stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None], + *, + no_inline_images: bool = False, + output_stats: bool = False, + user_input_manager: UserInputManager | None = None, +) -> T: + """ + Consumes the message stream from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` + or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream` and renders the messages to the console. + Returns the last processed TaskResult or Response. + + .. note:: + + `output_stats` is experimental and the stats may not be accurate. + It will be improved in future releases. + + Args: + stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render. + This can be from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`. + no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False. + output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False. + + Returns: + last_processed: A :class:`~agentdhal_agentchat.base.TaskResult` if the stream is from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` + or a :class:`~agentdhal_agentchat.base.Response` if the stream is from :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`. + """ + render_image_iterm = _is_running_in_iterm() and _is_output_a_tty() and not no_inline_images + start_time = time.time() + total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + + last_processed: Optional[T] = None + + streaming_chunks: List[str] = [] + + async for message in stream: + if isinstance(message, TaskResult): + duration = time.time() - start_time + if output_stats: + output = ( + f"{'-' * 10} Summary {'-' * 10}\n" + f"Number of messages: {len(message.messages)}\n" + f"Finish reason: {message.stop_reason}\n" + f"Total prompt tokens: {total_usage.prompt_tokens}\n" + f"Total completion tokens: {total_usage.completion_tokens}\n" + f"Duration: {duration:.2f} seconds\n" + ) + await aprint(output, end="", flush=True) + + # mypy ignore + last_processed = message # type: ignore + + elif isinstance(message, Response): + duration = time.time() - start_time + + # Print final response. + if isinstance(message.chat_message, MultiModalMessage): + final_content = message.chat_message.to_text(iterm=render_image_iterm) + else: + final_content = message.chat_message.to_text() + output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{final_content}\n" + if message.chat_message.models_usage: + if output_stats: + output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n" + total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens + total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens + await aprint(output, end="", flush=True) + + # Print summary. + if output_stats: + if message.inner_messages is not None: + num_inner_messages = len(message.inner_messages) + else: + num_inner_messages = 0 + output = ( + f"{'-' * 10} Summary {'-' * 10}\n" + f"Number of inner messages: {num_inner_messages}\n" + f"Total prompt tokens: {total_usage.prompt_tokens}\n" + f"Total completion tokens: {total_usage.completion_tokens}\n" + f"Duration: {duration:.2f} seconds\n" + ) + await aprint(output, end="", flush=True) + + # mypy ignore + last_processed = message # type: ignore + # We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event. + elif isinstance(message, UserInputRequestedEvent): + if user_input_manager is not None: + user_input_manager.notify_event_received(message.request_id) + else: + # Cast required for mypy to be happy + message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore + if not streaming_chunks: + # Print message sender. + await aprint( + f"{'-' * 10} {message.__class__.__name__} ({message.source}) {'-' * 10}", end="\n", flush=True + ) + if isinstance(message, ModelClientStreamingChunkEvent): + await aprint(message.to_text(), end="", flush=True) + streaming_chunks.append(message.content) + else: + if streaming_chunks: + streaming_chunks.clear() + # Chunked messages are already printed, so we just print a newline. + await aprint("", end="\n", flush=True) + elif isinstance(message, MultiModalMessage): + await aprint(message.to_text(iterm=render_image_iterm), end="\n", flush=True) + else: + await aprint(message.to_text(), end="\n", flush=True) + if message.models_usage: + if output_stats: + await aprint( + f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]", + end="\n", + flush=True, + ) + total_usage.completion_tokens += message.models_usage.completion_tokens + total_usage.prompt_tokens += message.models_usage.prompt_tokens + + if last_processed is None: + raise ValueError("No TaskResult or Response was processed.") + + return last_processed diff --git a/agent_dhal/agentdhal_agentchat/utils/__init__.py b/agent_dhal/agentdhal_agentchat/utils/__init__.py new file mode 100644 index 0000000..44de85b --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/utils/__init__.py @@ -0,0 +1,7 @@ +""" +This module implements various utilities common to AgentChat agents and teams. +""" + +from ._utils import content_to_str, remove_images + +__all__ = ["content_to_str", "remove_images"] diff --git a/agent_dhal/agentdhal_agentchat/utils/_utils.py b/agent_dhal/agentdhal_agentchat/utils/_utils.py new file mode 100644 index 0000000..39c13b0 --- /dev/null +++ b/agent_dhal/agentdhal_agentchat/utils/_utils.py @@ -0,0 +1,44 @@ +from typing import List, Union + +from agentdhal_core import FunctionCall, Image +from agentdhal_core.models import FunctionExecutionResult, LLMMessage, UserMessage +from pydantic import BaseModel + +# Type aliases for convenience +_StructuredContent = BaseModel +_UserContent = Union[str, List[Union[str, Image]]] +_AssistantContent = Union[str, List[FunctionCall]] +_FunctionExecutionContent = List[FunctionExecutionResult] +_SystemContent = str + + +def content_to_str( + content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent | _StructuredContent, +) -> str: + """Convert the content of an LLMMessage to a string.""" + if isinstance(content, str): + return content + elif isinstance(content, BaseModel): + return content.model_dump_json() + else: + result: List[str] = [] + for c in content: + if isinstance(c, str): + result.append(c) + elif isinstance(c, Image): + result.append("") + else: + result.append(str(c)) + + return "\n".join(result) + + +def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]: + """Remove images from a list of LLMMessages""" + str_messages: List[LLMMessage] = [] + for message in messages: + if isinstance(message, UserMessage) and isinstance(message.content, list): + str_messages.append(UserMessage(content=content_to_str(message.content), source=message.source)) + else: + str_messages.append(message) + return str_messages diff --git a/agent_dhal/agentdhal_core/__init__.py b/agent_dhal/agentdhal_core/__init__.py new file mode 100644 index 0000000..912b179 --- /dev/null +++ b/agent_dhal/agentdhal_core/__init__.py @@ -0,0 +1,142 @@ +# AgentDhal Core Module - Self-contained version +__version__ = "1.0.0" + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext +from ._agent_metadata import AgentMetadata +from ._agent_proxy import AgentProxy +from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType +from ._base_agent import BaseAgent +from ._cache_store import CacheStore, InMemoryStore +from ._cancellation_token import CancellationToken +from ._closure_agent import ClosureAgent, ClosureContext +from ._component_config import ( + Component, + ComponentBase, + ComponentFromConfig, + ComponentLoader, + ComponentModel, + ComponentSchemaType, + ComponentToConfig, + ComponentType, + is_component_class, + is_component_instance, +) +from ._constants import ( + EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS, +) +from ._constants import ( + ROOT_LOGGER_NAME as ROOT_LOGGER_NAME_ALIAS, +) +from ._constants import ( + TRACE_LOGGER_NAME as TRACE_LOGGER_NAME_ALIAS, +) +from ._default_subscription import DefaultSubscription, default_subscription, type_subscription +from ._default_topic import DefaultTopicId +from ._image import Image +from ._intervention import ( + DefaultInterventionHandler, + DropMessage, + InterventionHandler, +) +from ._message_context import MessageContext +from ._message_handler_context import MessageHandlerContext +from ._routed_agent import RoutedAgent, event, message_handler, rpc +from ._serialization import ( + JSON_DATA_CONTENT_TYPE as JSON_DATA_CONTENT_TYPE_ALIAS, +) +from ._serialization import ( + PROTOBUF_DATA_CONTENT_TYPE as PROTOBUF_DATA_CONTENT_TYPE_ALIAS, +) +from ._serialization import ( + MessageSerializer, + UnknownPayload, + try_get_known_serializers_for_type, +) +from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime +from ._subscription import Subscription +from ._subscription_context import SubscriptionInstantiationContext +from ._telemetry import ( + trace_create_agent_span, + trace_invoke_agent_span, + trace_tool_span, +) +from ._topic import TopicId +from ._type_prefix_subscription import TypePrefixSubscription +from ._type_subscription import TypeSubscription +from ._types import FunctionCall + +EVENT_LOGGER_NAME = EVENT_LOGGER_NAME_ALIAS +"""The name of the logger used for structured events.""" + +ROOT_LOGGER_NAME = ROOT_LOGGER_NAME_ALIAS +"""The name of the root logger.""" + +TRACE_LOGGER_NAME = TRACE_LOGGER_NAME_ALIAS +"""Logger name used for developer intended trace logging. The content and format of this log should not be depended upon.""" + +JSON_DATA_CONTENT_TYPE = JSON_DATA_CONTENT_TYPE_ALIAS +"""The content type for JSON data.""" + +PROTOBUF_DATA_CONTENT_TYPE = PROTOBUF_DATA_CONTENT_TYPE_ALIAS +"""The content type for Protobuf data.""" + +__all__ = [ + "Agent", + "AgentId", + "AgentProxy", + "AgentMetadata", + "AgentRuntime", + "BaseAgent", + "CacheStore", + "InMemoryStore", + "CancellationToken", + "AgentInstantiationContext", + "TopicId", + "Subscription", + "MessageContext", + "AgentType", + "SubscriptionInstantiationContext", + "MessageHandlerContext", + "MessageSerializer", + "try_get_known_serializers_for_type", + "UnknownPayload", + "Image", + "RoutedAgent", + "ClosureAgent", + "ClosureContext", + "message_handler", + "event", + "rpc", + "FunctionCall", + "TypeSubscription", + "DefaultSubscription", + "DefaultTopicId", + "default_subscription", + "type_subscription", + "TypePrefixSubscription", + "JSON_DATA_CONTENT_TYPE", + "PROTOBUF_DATA_CONTENT_TYPE", + "SingleThreadedAgentRuntime", + "ROOT_LOGGER_NAME", + "EVENT_LOGGER_NAME", + "TRACE_LOGGER_NAME", + "Component", + "ComponentBase", + "ComponentFromConfig", + "ComponentLoader", + "ComponentModel", + "ComponentSchemaType", + "ComponentToConfig", + "ComponentType", + "is_component_class", + "is_component_instance", + "DropMessage", + "InterventionHandler", + "DefaultInterventionHandler", + "trace_create_agent_span", + "trace_invoke_agent_span", + "trace_tool_span", +] diff --git a/agent_dhal/agentdhal_core/_agent.py b/agent_dhal/agentdhal_core/_agent.py new file mode 100644 index 0000000..e407fe1 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent.py @@ -0,0 +1,64 @@ +from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable + +from ._agent_id import AgentId +from ._agent_metadata import AgentMetadata +from ._message_context import MessageContext + +# Forward declaration for type checking only +if TYPE_CHECKING: + from ._agent_runtime import AgentRuntime + + +@runtime_checkable +class Agent(Protocol): + @property + def metadata(self) -> AgentMetadata: + """Metadata of the agent.""" + ... + + @property + def id(self) -> AgentId: + """ID of the agent.""" + ... + + async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None: + """Function used to bind an Agent instance to an `AgentRuntime`. + + Args: + agent_id (AgentId): ID of the agent. + runtime (AgentRuntime): AgentRuntime instance to bind the agent to. + """ + ... + + async def on_message(self, message: Any, ctx: MessageContext) -> Any: + """Message handler for the agent. This should only be called by the runtime, not by other agents. + + Args: + message (Any): Received message. Type is one of the types in `subscriptions`. + ctx (MessageContext): Context of the message. + + Returns: + Any: Response to the message. Can be None. + + Raises: + asyncio.CancelledError: If the message was cancelled. + CantHandleException: If the agent cannot handle the message. + """ + ... + + async def save_state(self) -> Mapping[str, Any]: + """Save the state of the agent. The result must be JSON serializable.""" + ... + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load in the state of the agent obtained from `save_state`. + + Args: + state (Mapping[str, Any]): State of the agent. Must be JSON serializable. + """ + + ... + + async def close(self) -> None: + """Called when the runtime is closed""" + ... diff --git a/agent_dhal/agentdhal_core/_agent_id.py b/agent_dhal/agentdhal_core/_agent_id.py new file mode 100644 index 0000000..de04659 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_id.py @@ -0,0 +1,68 @@ +import re + +from typing_extensions import Self + +from ._agent_type import AgentType + + +def is_valid_agent_type(value: str) -> bool: + return bool(re.match(r"^[\w\-\.]+\Z", value)) + + +class AgentId: + """ + Agent ID uniquely identifies an agent instance within an agent runtime - including distributed runtime. It is the 'address' of the agent instance for receiving messages. + + See here for more information: :ref:`agentid_and_lifecycle` + """ + + def __init__(self, type: str | AgentType, key: str) -> None: + if isinstance(type, AgentType): + type = type.type + + if not is_valid_agent_type(type): + raise ValueError(rf"Invalid agent type: {type}. Allowed values MUST match the regex: `^[\w\-\.]+\Z`") + + self._type = type + self._key = key + + def __hash__(self) -> int: + return hash((self._type, self._key)) + + def __str__(self) -> str: + return f"{self._type}/{self._key}" + + def __repr__(self) -> str: + return f'AgentId(type="{self._type}", key="{self._key}")' + + def __eq__(self, value: object) -> bool: + if not isinstance(value, AgentId): + return False + return self._type == value.type and self._key == value.key + + @classmethod + def from_str(cls, agent_id: str) -> Self: + """Convert a string of the format ``type/key`` into an AgentId""" + items = agent_id.split("/", maxsplit=1) + if len(items) != 2: + raise ValueError(f"Invalid agent id: {agent_id}") + type, key = items[0], items[1] + return cls(type, key) + + @property + def type(self) -> str: + """ + An identifier that associates an agent with a specific factory function. + + Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_). + """ + return self._type + + @property + def key(self) -> str: + """ + Agent instance identifier. + + Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_). + """ + return self._key diff --git a/agent_dhal/agentdhal_core/_agent_instantiation.py b/agent_dhal/agentdhal_core/_agent_instantiation.py new file mode 100644 index 0000000..28f6a0f --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_instantiation.py @@ -0,0 +1,126 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, ClassVar, Generator + +from ._agent_id import AgentId +from ._agent_runtime import AgentRuntime + + +class AgentInstantiationContext: + """A static class that provides context for agent instantiation. + + This static class can be used to access the current runtime and agent ID + during agent instantiation -- inside the factory function or the agent's + class constructor. + + Example: + + Get the current runtime and agent ID inside the factory function and + the agent's constructor: + + .. code-block:: python + + import asyncio + from dataclasses import dataclass + + from agentdhal_core import ( + AgentId, + AgentInstantiationContext, + MessageContext, + RoutedAgent, + SingleThreadedAgentRuntime, + message_handler, + ) + + + @dataclass + class TestMessage: + content: str + + + class TestAgent(RoutedAgent): + def __init__(self, description: str): + super().__init__(description) + # Get the current runtime -- we don't use it here, but it's available. + _ = AgentInstantiationContext.current_runtime() + # Get the current agent ID. + agent_id = AgentInstantiationContext.current_agent_id() + print(f"Current AgentID from constructor: {agent_id}") + + @message_handler + async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None: + print(f"Received message: {message.content}") + + + def test_agent_factory() -> TestAgent: + # Get the current runtime -- we don't use it here, but it's available. + _ = AgentInstantiationContext.current_runtime() + # Get the current agent ID. + agent_id = AgentInstantiationContext.current_agent_id() + print(f"Current AgentID from factory: {agent_id}") + return TestAgent(description="Test agent") + + + async def main() -> None: + # Create a SingleThreadedAgentRuntime instance. + runtime = SingleThreadedAgentRuntime() + + # Start the runtime. + runtime.start() + + # Register the agent type with a factory function. + await runtime.register_factory("test_agent", test_agent_factory) + + # Send a message to the agent. The runtime will instantiate the agent and call the message handler. + await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default")) + + # Stop the runtime. + await runtime.stop() + + + asyncio.run(main()) + + """ + + def __init__(self) -> None: + raise RuntimeError( + "AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation." + ) + + _AGENT_INSTANTIATION_CONTEXT_VAR: ClassVar[ContextVar[tuple[AgentRuntime, AgentId]]] = ContextVar( + "_AGENT_INSTANTIATION_CONTEXT_VAR" + ) + + @classmethod + @contextmanager + def populate_context(cls, ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]: + """:meta private:""" + token = AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx) + try: + yield + finally: + AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.reset(token) + + @classmethod + def current_runtime(cls) -> AgentRuntime: + try: + return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[0] + except LookupError as e: + raise RuntimeError( + "AgentInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." + ) from e + + @classmethod + def current_agent_id(cls) -> AgentId: + try: + return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[1] + except LookupError as e: + raise RuntimeError( + "AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." + ) from e + + @classmethod + def is_in_factory_call(cls) -> bool: + if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: + return False + return True diff --git a/agent_dhal/agentdhal_core/_agent_metadata.py b/agent_dhal/agentdhal_core/_agent_metadata.py new file mode 100644 index 0000000..abdf920 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_metadata.py @@ -0,0 +1,7 @@ +from typing import TypedDict + + +class AgentMetadata(TypedDict): + type: str + key: str + description: str diff --git a/agent_dhal/agentdhal_core/_agent_proxy.py b/agent_dhal/agentdhal_core/_agent_proxy.py new file mode 100644 index 0000000..f6ee258 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_proxy.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Mapping + +from ._agent_id import AgentId +from ._agent_metadata import AgentMetadata +from ._cancellation_token import CancellationToken + +if TYPE_CHECKING: + from ._agent_runtime import AgentRuntime + + +class AgentProxy: + """A helper class that allows you to use an :class:`~agentdhal_core.AgentId` in place of its associated :class:`~agentdhal_core.Agent`""" + + def __init__(self, agent: AgentId, runtime: AgentRuntime): + self._agent = agent + self._runtime = runtime + + @property + def id(self) -> AgentId: + """Target agent for this proxy""" + return self._agent + + @property + def metadata(self) -> Awaitable[AgentMetadata]: + """Metadata of the agent.""" + return self._runtime.agent_metadata(self._agent) + + async def send_message( + self, + message: Any, + *, + sender: AgentId, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: + return await self._runtime.send_message( + message, + recipient=self._agent, + sender=sender, + cancellation_token=cancellation_token, + message_id=message_id, + ) + + async def save_state(self) -> Mapping[str, Any]: + """Save the state of the agent. The result must be JSON serializable.""" + return await self._runtime.agent_save_state(self._agent) + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load in the state of the agent obtained from `save_state`. + + Args: + state (Mapping[str, Any]): State of the agent. Must be JSON serializable. + """ + await self._runtime.agent_load_state(self._agent, state) diff --git a/agent_dhal/agentdhal_core/_agent_runtime.py b/agent_dhal/agentdhal_core/_agent_runtime.py new file mode 100644 index 0000000..70f8ef3 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_runtime.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_metadata import AgentMetadata +from ._agent_type import AgentType +from ._cancellation_token import CancellationToken +from ._serialization import MessageSerializer +from ._subscription import Subscription +from ._topic import TopicId + +# Undeliverable - error + +T = TypeVar("T", bound=Agent) + + +@runtime_checkable +class AgentRuntime(Protocol): + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: + """Send a message to an agent and get a response. + + Args: + message (Any): The message to send. + recipient (AgentId): The agent to send the message to. + sender (AgentId | None, optional): Agent which sent the message. Should **only** be None if this was sent from no agent, such as directly to the runtime externally. Defaults to None. + cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None. + + Raises: + CantHandleException: If the recipient cannot handle the message. + UndeliverableException: If the message cannot be delivered. + Other: Any other exception raised by the recipient. + + Returns: + Any: The response from the agent. + """ + + ... + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> None: + """Publish a message to all agents in the given namespace, or if no namespace is provided, the namespace of the sender. + + No responses are expected from publishing. + + Args: + message (Any): The message to publish. + topic_id (TopicId): The topic to publish the message to. + sender (AgentId | None, optional): The agent which sent the message. Defaults to None. + cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress. Defaults to None. + message_id (str | None, optional): The message id. If None, a new message id will be generated. Defaults to None. This message id must be unique. and is recommended to be a UUID. + + Raises: + UndeliverableException: If the message cannot be delivered. + """ + ... + + async def register_factory( + self, + type: str | AgentType, + agent_factory: Callable[[], T | Awaitable[T]], + *, + expected_class: type[T] | None = None, + ) -> AgentType: + """Register an agent factory with the runtime associated with a specific type. The type must be unique. This API does not add any subscriptions. + + .. note:: + + This is a low level API and usually the agent class's `register` method should be used instead, as this also handles subscriptions automatically. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + + from agentdhal_core import AgentRuntime, MessageContext, RoutedAgent, event + from agentdhal_core.models import UserMessage + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My core agent") + + @event + async def handler(self, message: UserMessage, context: MessageContext) -> None: + print("Event received: ", message.content) + + + async def my_agent_factory(): + return MyAgent() + + + async def main() -> None: + runtime: AgentRuntime = ... # type: ignore + await runtime.register_factory("my_agent", lambda: MyAgent()) + + + import asyncio + + asyncio.run(main()) + + + Args: + type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes. + agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agentdhal_core.AgentInstantiationContext` to access variables like the current runtime and agent ID. + expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed. + """ + ... + + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + """Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions. + + .. note:: + + This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + + from agentdhal_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event + from agentdhal_core.models import UserMessage + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My core agent") + + @event + async def handler(self, message: UserMessage, context: MessageContext) -> None: + print("Event received: ", message.content) + + + async def main() -> None: + runtime: AgentRuntime = ... # type: ignore + agent = MyAgent() + await runtime.register_agent_instance( + agent_instance=agent, agent_id=AgentId(type="my_agent", key="default") + ) + + + import asyncio + + asyncio.run(main()) + + + Args: + agent_instance (Agent): A concrete instance of the agent. + agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. + """ + ... + + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + """Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases. + + If the underlying agent is not accessible, this will raise an exception. + + Args: + id (AgentId): The agent id. + type (Type[T], optional): The expected type of the agent. Defaults to Agent. + + Returns: + T: The concrete agent instance. + + Raises: + LookupError: If the agent is not found. + NotAccessibleError: If the agent is not accessible, for example if it is located remotely. + TypeError: If the agent is not of the expected type. + """ + ... + + @overload + async def get(self, id: AgentId, /, *, lazy: bool = ...) -> AgentId: ... + + @overload + async def get(self, type: AgentType | str, /, key: str = ..., *, lazy: bool = ...) -> AgentId: ... + + async def get( + self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True + ) -> AgentId: ... + + async def save_state(self) -> Mapping[str, Any]: + """Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`. + + The structure of the state is implementation defined and can be any JSON serializable object. + + Returns: + Mapping[str, Any]: The saved state. + """ + ... + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`. + + Args: + state (Mapping[str, Any]): The saved state. + """ + ... + + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: + """Get the metadata for an agent. + + Args: + agent (AgentId): The agent id. + + Returns: + AgentMetadata: The agent metadata. + """ + ... + + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + """Save the state of a single agent. + + The structure of the state is implementation defined and can be any JSON serializable object. + + Args: + agent (AgentId): The agent id. + + Returns: + Mapping[str, Any]: The saved state. + """ + ... + + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + """Load the state of a single agent. + + Args: + agent (AgentId): The agent id. + state (Mapping[str, Any]): The saved state. + """ + ... + + async def add_subscription(self, subscription: Subscription) -> None: + """Add a new subscription that the runtime should fulfill when processing published messages + + Args: + subscription (Subscription): The subscription to add + """ + ... + + async def remove_subscription(self, id: str) -> None: + """Remove a subscription from the runtime + + Args: + id (str): id of the subscription to remove + + Raises: + LookupError: If the subscription does not exist + """ + ... + + def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: + """Add a new message serialization serializer to the runtime + + Note: This will deduplicate serializers based on the type_name and data_content_type properties + + Args: + serializer (MessageSerializer[Any] | Sequence[MessageSerializer[Any]]): The serializer/s to add + """ + ... diff --git a/agent_dhal/agentdhal_core/_agent_type.py b/agent_dhal/agentdhal_core/_agent_type.py new file mode 100644 index 0000000..009f8c9 --- /dev/null +++ b/agent_dhal/agentdhal_core/_agent_type.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class AgentType: + type: str + """String representation of this agent type.""" diff --git a/agent_dhal/agentdhal_core/_base_agent.py b/agent_dhal/agentdhal_core/_base_agent.py new file mode 100644 index 0000000..6fcb54f --- /dev/null +++ b/agent_dhal/agentdhal_core/_base_agent.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import inspect +import warnings +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final + +from typing_extensions import Self + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext +from ._agent_metadata import AgentMetadata +from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType +from ._cancellation_token import CancellationToken +from ._message_context import MessageContext +from ._serialization import MessageSerializer, try_get_known_serializers_for_type +from ._subscription import Subscription, UnboundSubscription +from ._subscription_context import SubscriptionInstantiationContext +from ._topic import TopicId +from ._type_prefix_subscription import TypePrefixSubscription +from ._type_subscription import TypeSubscription + +T = TypeVar("T", bound=Agent) + +BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent") + + +# Decorator for adding an unbound subscription to an agent +def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: + """:meta private:""" + + def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: + cls.internal_unbound_subscriptions_list.append(subscription) + return cls + + return decorator + + +def handles( + type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None +) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: + def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: + if serializer is None: + serializer_list = try_get_known_serializers_for_type(type) + else: + serializer_list = [serializer] if not isinstance(serializer, Sequence) else serializer + + if len(serializer_list) == 0: + raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.") + + cls.internal_extra_handles_types.append((type, serializer_list)) + return cls + + return decorator + + +class BaseAgent(ABC, Agent): + internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = [] + """:meta private:""" + internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = [] + """:meta private:""" + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Automatically set class_variable in each subclass so that they are not shared between subclasses + cls.internal_extra_handles_types = [] + cls.internal_unbound_subscriptions_list = [] + + @classmethod + def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]: + return cls.internal_extra_handles_types + + @classmethod + def _unbound_subscriptions(cls) -> List[UnboundSubscription]: + return cls.internal_unbound_subscriptions_list + + @property + def metadata(self) -> AgentMetadata: + assert self._id is not None + return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) + + def __init__(self, description: str) -> None: + if AgentInstantiationContext.is_in_factory_call(): + self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() + self._id = AgentInstantiationContext.current_agent_id() + if not isinstance(description, str): + raise ValueError("Agent description must be a string") + self._description = description + + async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None: + if hasattr(self, "_id"): + if self._id != id: + raise RuntimeError("Agent is already bound to a different ID") + + if hasattr(self, "_runtime"): + if self._runtime != runtime: + raise RuntimeError("Agent is already bound to a different runtime") + + self._id = id + self._runtime = runtime + + @property + def type(self) -> str: + return self.id.type + + @property + def id(self) -> AgentId: + return self._id + + @property + def runtime(self) -> AgentRuntime: + return self._runtime + + @final + async def on_message(self, message: Any, ctx: MessageContext) -> Any: + return await self.on_message_impl(message, ctx) + + @abstractmethod + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: ... + + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: + """See :py:meth:`agentdhal_core.AgentRuntime.send_message` for more information.""" + if cancellation_token is None: + cancellation_token = CancellationToken() + + return await self._runtime.send_message( + message, + sender=self.id, + recipient=recipient, + cancellation_token=cancellation_token, + message_id=message_id, + ) + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + cancellation_token: CancellationToken | None = None, + ) -> None: + await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token) + + async def save_state(self) -> Mapping[str, Any]: + warnings.warn("save_state not implemented", stacklevel=2) + return {} + + async def load_state(self, state: Mapping[str, Any]) -> None: + warnings.warn("load_state not implemented", stacklevel=2) + pass + + async def close(self) -> None: + pass + + async def register_instance( + self, + runtime: AgentRuntime, + agent_id: AgentId, + *, + skip_class_subscriptions: bool = True, + skip_direct_message_subscription: bool = False, + ) -> AgentId: + """ + This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime. + """ + agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id) + + id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type) + await runtime.add_subscription(id_subscription) + + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): + subscriptions: List[Subscription] = [] + for unbound_subscription in self._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + for subscription in subscriptions: + await runtime.add_subscription(subscription) + + if not skip_direct_message_subscription: + # Additionally adds a special prefix subscription for this agent to receive direct messages + try: + await runtime.add_subscription( + TypePrefixSubscription( + # The prefix MUST include ":" to avoid collisions with other agents + topic_type_prefix=agent_id.type + ":", + agent_type=agent_id.type, + ) + ) + except ValueError: + # We don't care if the subscription already exists + pass + + # TODO: deduplication + for _message_type, serializer in self._handles_types(): + runtime.add_message_serializer(serializer) + + return agent_id + + @classmethod + async def register( + cls, + runtime: AgentRuntime, + type: str, + factory: Callable[[], Self | Awaitable[Self]], + *, + skip_class_subscriptions: bool = False, + skip_direct_message_subscription: bool = False, + ) -> AgentType: + agent_type = AgentType(type) + agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls) + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(agent_type): + subscriptions: List[Subscription] = [] + for unbound_subscription in cls._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + for subscription in subscriptions: + await runtime.add_subscription(subscription) + + if not skip_direct_message_subscription: + # Additionally adds a special prefix subscription for this agent to receive direct messages + await runtime.add_subscription( + TypePrefixSubscription( + # The prefix MUST include ":" to avoid collisions with other agents + topic_type_prefix=agent_type.type + ":", + agent_type=agent_type.type, + ) + ) + + # TODO: deduplication + for _message_type, serializer in cls._handles_types(): + runtime.add_message_serializer(serializer) + + return agent_type diff --git a/agent_dhal/agentdhal_core/_cache_store.py b/agent_dhal/agentdhal_core/_cache_store.py new file mode 100644 index 0000000..bdf026b --- /dev/null +++ b/agent_dhal/agentdhal_core/_cache_store.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod +from typing import Dict, Generic, Optional, TypeVar + +from pydantic import BaseModel +from typing_extensions import Self + +from ._component_config import Component, ComponentBase + +T = TypeVar("T") + + +class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]): + """ + This protocol defines the basic interface for store/cache operations. + + Sub-classes should handle the lifecycle of underlying storage. + """ + + component_type = "cache_store" + + @abstractmethod + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + """ + Retrieve an item from the store. + + Args: + key: The key identifying the item in the store. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The value associated with the key if found, else the default value. + """ + ... + + @abstractmethod + def set(self, key: str, value: T) -> None: + """ + Set an item in the store. + + Args: + key: The key under which the item is to be stored. + value: The value to be stored in the store. + """ + ... + + +class InMemoryStoreConfig(BaseModel): + pass + + +class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]): + component_provider_override = "agentdhal_core.InMemoryStore" + component_config_schema = InMemoryStoreConfig + + def __init__(self) -> None: + self.store: Dict[str, T] = {} + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + return self.store.get(key, default) + + def set(self, key: str, value: T) -> None: + self.store[key] = value + + def _to_config(self) -> InMemoryStoreConfig: + return InMemoryStoreConfig() + + @classmethod + def _from_config(cls, config: InMemoryStoreConfig) -> Self: + return cls() diff --git a/agent_dhal/agentdhal_core/_cancellation_token.py b/agent_dhal/agentdhal_core/_cancellation_token.py new file mode 100644 index 0000000..06e037e --- /dev/null +++ b/agent_dhal/agentdhal_core/_cancellation_token.py @@ -0,0 +1,46 @@ +import threading +from asyncio import Future +from typing import Any, Callable, List + + +class CancellationToken: + """A token used to cancel pending async calls""" + + def __init__(self) -> None: + self._cancelled: bool = False + self._lock: threading.Lock = threading.Lock() + self._callbacks: List[Callable[[], None]] = [] + + def cancel(self) -> None: + """Cancel pending async calls linked to this cancellation token.""" + with self._lock: + if not self._cancelled: + self._cancelled = True + for callback in self._callbacks: + callback() + + def is_cancelled(self) -> bool: + """Check if the CancellationToken has been used""" + with self._lock: + return self._cancelled + + def add_callback(self, callback: Callable[[], None]) -> None: + """Attach a callback that will be called when cancel is invoked""" + with self._lock: + if self._cancelled: + callback() + else: + self._callbacks.append(callback) + + def link_future(self, future: Future[Any]) -> Future[Any]: + """Link a pending async call to a token to allow its cancellation""" + with self._lock: + if self._cancelled: + future.cancel() + else: + + def _cancel() -> None: + future.cancel() + + self._callbacks.append(_cancel) + return future diff --git a/agent_dhal/agentdhal_core/_closure_agent.py b/agent_dhal/agentdhal_core/_closure_agent.py new file mode 100644 index 0000000..05befc4 --- /dev/null +++ b/agent_dhal/agentdhal_core/_closure_agent.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import inspect +import warnings +from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints + +from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext +from ._agent_metadata import AgentMetadata +from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType +from ._base_agent import BaseAgent +from ._cancellation_token import CancellationToken +from ._message_context import MessageContext +from ._serialization import try_get_known_serializers_for_type +from ._subscription import Subscription +from ._subscription_context import SubscriptionInstantiationContext +from ._topic import TopicId +from ._type_helpers import get_types +from .exceptions import CantHandleException + +T = TypeVar("T") +ClosureAgentType = TypeVar("ClosureAgentType", bound="ClosureAgent") + + +def get_handled_types_from_closure( + closure: Callable[[ClosureAgent, T, MessageContext], Awaitable[Any]], +) -> Sequence[type]: + args = inspect.getfullargspec(closure)[0] + if len(args) != 3: + raise AssertionError("Closure must have 4 arguments") + + message_arg_name = args[1] + + type_hints = get_type_hints(closure) + + if "return" not in type_hints: + raise AssertionError("return not found in function signature") + + # Get the type of the message parameter + target_types = get_types(type_hints[message_arg_name]) + if target_types is None: + raise AssertionError("Message type not found") + + # print(type_hints) + return_types = get_types(type_hints["return"]) + + if return_types is None: + raise AssertionError("Return type not found") + + return target_types + + +class ClosureContext(Protocol): + @property + def id(self) -> AgentId: ... + + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: ... + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + cancellation_token: CancellationToken | None = None, + ) -> None: ... + + +class ClosureAgent(BaseAgent, ClosureContext): + def __init__( + self, + description: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + unknown_type_policy: Literal["error", "warn", "ignore"] = "warn", + ) -> None: + try: + runtime = AgentInstantiationContext.current_runtime() + id = AgentInstantiationContext.current_agent_id() + except Exception as e: + raise RuntimeError( + "ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." + ) from e + + self._runtime: AgentRuntime = runtime + self._id: AgentId = id + self._description = description + handled_types = get_handled_types_from_closure(closure) + self._expected_types = handled_types + self._closure = closure + self._unknown_type_policy = unknown_type_policy + super().__init__(description) + + @property + def metadata(self) -> AgentMetadata: + assert self._id is not None + return AgentMetadata( + key=self._id.key, + type=self._id.type, + description=self._description, + ) + + @property + def id(self) -> AgentId: + return self._id + + @property + def runtime(self) -> AgentRuntime: + return self._runtime + + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: + if type(message) not in self._expected_types: + if self._unknown_type_policy == "warn": + warnings.warn( + f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.", + stacklevel=1, + ) + return None + elif self._unknown_type_policy == "error": + raise CantHandleException( + f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning." + ) + + return await self._closure(self, message, ctx) + + async def save_state(self) -> Mapping[str, Any]: + """Closure agents do not have state. So this method always returns an empty dictionary.""" + return {} + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Closure agents do not have state. So this method does nothing.""" + pass + + @classmethod + async def register_closure( + cls, + runtime: AgentRuntime, + type: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + unknown_type_policy: Literal["error", "warn", "ignore"] = "warn", + skip_direct_message_subscription: bool = False, + description: str = "", + subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, + ) -> AgentType: + """The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime. + + The closure can define the type of message which is expected, or `Any` can be used to accept any type of message. + + Example: + + .. code-block:: python + + import asyncio + from agentdhal_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext + from dataclasses import dataclass + + from agentdhal_core._default_subscription import DefaultSubscription + from agentdhal_core._default_topic import DefaultTopicId + + + @dataclass + class MyMessage: + content: str + + + async def main(): + queue = asyncio.Queue[MyMessage]() + + async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None: + await queue.put(message) + + runtime = SingleThreadedAgentRuntime() + await ClosureAgent.register_closure( + runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()] + ) + + runtime.start() + await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId()) + await runtime.stop_when_idle() + + result = await queue.get() + print(result) + + + asyncio.run(main()) + + + Args: + runtime (AgentRuntime): Runtime to register the agent to + type (str): Agent type of registered agent + closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages + unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn". + skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False. + description (str, optional): Description of what agent does. Defaults to "". + subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None. + + Returns: + AgentType: Type of the agent that was registered + """ + + def factory() -> ClosureAgent: + return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy) + + assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions" + agent_type = await cls.register( + runtime=runtime, + type=type, + factory=factory, # type: ignore + # There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s + skip_class_subscriptions=True, + skip_direct_message_subscription=skip_direct_message_subscription, + ) + + subscriptions_list: List[Subscription] = [] + if subscriptions is not None: + with SubscriptionInstantiationContext.populate_context(agent_type): + subscriptions_list_result = subscriptions() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list.extend(await subscriptions_list_result) + else: + # just ignore mypy here + subscriptions_list.extend(subscriptions_list_result) # type: ignore + + for subscription in subscriptions_list: + await runtime.add_subscription(subscription) + + handled_types = get_handled_types_from_closure(closure) + for message_type in handled_types: + # TODO: support custom serializers + serializer = try_get_known_serializers_for_type(message_type) + runtime.add_message_serializer(serializer) + + return agent_type diff --git a/agent_dhal/agentdhal_core/_component_config.py b/agent_dhal/agentdhal_core/_component_config.py new file mode 100644 index 0000000..aa66c8d --- /dev/null +++ b/agent_dhal/agentdhal_core/_component_config.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +import importlib +import warnings +from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload + +from pydantic import BaseModel +from typing_extensions import Self, TypeVar + +ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str +ConfigT = TypeVar("ConfigT", bound=BaseModel) +FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True) +ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True) + +T = TypeVar("T", bound=BaseModel, covariant=True) + + +class ComponentModel(BaseModel): + """Model class for a component. Contains all information required to instantiate a component.""" + + provider: str + """Describes how the component can be instantiated.""" + + component_type: ComponentType | None = None + """Logical type of the component. If missing, the component assumes the default type of the provider.""" + + version: int | None = None + """Version of the component specification. If missing, the component assumes whatever is the current version of the library used to load it. This is obviously dangerous and should be used for user authored ephmeral config. For all other configs version should be specified.""" + + component_version: int | None = None + """Version of the component. If missing, the component assumes the default version of the provider.""" + + description: str | None = None + """Description of the component.""" + + label: str | None = None + """Human readable label for the component. If missing the component assumes the class name of the provider.""" + + config: dict[str, Any] + """The schema validated config field is passed to a given class's implmentation of :py:meth:`agentdhal_core.ComponentConfigImpl._from_config` to create a new instance of the component class.""" + + +def _type_to_provider_str(t: type) -> str: + return f"{t.__module__}.{t.__qualname__}" + + +WELL_KNOWN_PROVIDERS = { + "azure_openai_chat_completion_client": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient", + "AzureOpenAIChatCompletionClient": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient", + "openai_chat_completion_client": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient", + "OpenAIChatCompletionClient": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient", +} + + +class ComponentFromConfig(Generic[FromConfigT]): + @classmethod + def _from_config(cls, config: FromConfigT) -> Self: + """Create a new instance of the component from a configuration object. + + Args: + config (T): The configuration object. + + Returns: + Self: The new instance of the component. + + :meta public: + """ + raise NotImplementedError("This component does not support dumping to config") + + @classmethod + def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self: + """Create a new instance of the component from a previous version of the configuration object. + + This is only called when the version of the configuration object is less than the current version, since in this case the schema is not known. + + Args: + config (Dict[str, Any]): The configuration object. + version (int): The version of the configuration object. + + Returns: + Self: The new instance of the component. + + :meta public: + """ + raise NotImplementedError("This component does not support loading from past versions") + + +class ComponentToConfig(Generic[ToConfigT]): + """The two methods a class must implement to be a component. + + Args: + Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`. + """ + + component_type: ClassVar[ComponentType] + """The logical type of the component.""" + component_version: ClassVar[int] = 1 + """The version of the component, if schema incompatibilities are introduced this should be updated.""" + component_provider_override: ClassVar[str | None] = None + """Override the provider string for the component. This should be used to prevent internal module names being a part of the module name.""" + component_description: ClassVar[str | None] = None + """A description of the component. If not provided, the docstring of the class will be used.""" + component_label: ClassVar[str | None] = None + """A human readable label for the component. If not provided, the component class name will be used.""" + + def _to_config(self) -> ToConfigT: + """Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance. + + Returns: + T: The configuration of the component. + + :meta public: + """ + raise NotImplementedError("This component does not support dumping to config") + + def dump_component(self) -> ComponentModel: + """Dump the component to a model that can be loaded back in. + + Raises: + TypeError: If the component is a local class. + + Returns: + ComponentModel: The model representing the component. + """ + if self.component_provider_override is not None: + provider = self.component_provider_override + else: + provider = _type_to_provider_str(self.__class__) + # Warn if internal module name is used, + if "._" in provider: + warnings.warn( + "Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.", + stacklevel=2, + ) + + if "" in provider: + raise TypeError("Cannot dump component with local class") + + if not hasattr(self, "component_type"): + raise AttributeError("component_type not defined") + + description = self.component_description + if description is None and self.__class__.__doc__: + # use docstring as description + docstring = self.__class__.__doc__.strip() + for marker in ["\n\nArgs:", "\n\nParameters:", "\n\nAttributes:", "\n\n"]: + docstring = docstring.split(marker)[0] + description = docstring.strip() + + obj_config = self._to_config().model_dump(exclude_none=True) + model = ComponentModel( + provider=provider, + component_type=self.component_type, + version=self.component_version, + component_version=self.component_version, + description=description, + label=self.component_label or self.__class__.__name__, + config=obj_config, + ) + return model + + +ExpectedType = TypeVar("ExpectedType") + + +class ComponentLoader: + @overload + @classmethod + def load_component(cls, model: ComponentModel | Dict[str, Any], expected: None = None) -> Self: ... + + @overload + @classmethod + def load_component(cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType]) -> ExpectedType: ... + + @classmethod + def load_component( + cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType] | None = None + ) -> Self | ExpectedType: + """Load a component from a model. Intended to be used with the return type of :py:meth:`agentdhal_core.ComponentConfig.dump_component`. + + Example: + + .. code-block:: python + + from agentdhal_core import ComponentModel + from agentdhal_core.models import ChatCompletionClient + + component: ComponentModel = ... # type: ignore + + model_client = ChatCompletionClient.load_component(component) + + Args: + model (ComponentModel): The model to load the component from. + + Returns: + Self: The loaded component. + + Args: + model (ComponentModel): _description_ + expected (Type[ExpectedType] | None, optional): Explicit type only if used directly on ComponentLoader. Defaults to None. + + Raises: + ValueError: If the provider string is invalid. + TypeError: Provider is not a subclass of ComponentConfigImpl, or the expected type does not match. + + Returns: + Self | ExpectedType: The loaded component. + """ + + # Use global and add further type checks + + if isinstance(model, dict): + loaded_model = ComponentModel(**model) + else: + loaded_model = model + + # First, do a look up in well known providers + if loaded_model.provider in WELL_KNOWN_PROVIDERS: + loaded_model.provider = WELL_KNOWN_PROVIDERS[loaded_model.provider] + + output = loaded_model.provider.rsplit(".", maxsplit=1) + if len(output) != 2: + raise ValueError("Invalid") + + module_path, class_name = output + module = importlib.import_module(module_path) + component_class = module.__getattribute__(class_name) + + if not is_component_class(component_class): + raise TypeError("Invalid component class") + + # We need to check the schema is valid + if not hasattr(component_class, "component_config_schema"): + raise AttributeError("component_config_schema not defined") + + if not hasattr(component_class, "component_type"): + raise AttributeError("component_type not defined") + + loaded_config_version = loaded_model.component_version or component_class.component_version + if loaded_config_version < component_class.component_version: + try: + instance = component_class._from_config_past_version(loaded_model.config, loaded_config_version) # type: ignore + except NotImplementedError as e: + raise NotImplementedError( + f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented" + ) from e + else: + schema = component_class.component_config_schema # type: ignore + validated_config = schema.model_validate(loaded_model.config) + + # We're allowed to use the private method here + instance = component_class._from_config(validated_config) # type: ignore + + if expected is None and not isinstance(instance, cls): + raise TypeError("Expected type does not match") + elif expected is None: + return cast(Self, instance) + elif not isinstance(instance, expected): + raise TypeError("Expected type does not match") + else: + return cast(ExpectedType, instance) + + +class ComponentSchemaType(Generic[ConfigT]): + # Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context) + component_config_schema: Type[ConfigT] + """The Pydantic model class which represents the configuration of the component.""" + + required_class_vars = ["component_config_schema", "component_type"] + + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + + if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent": + # TODO: validate provider is loadable + for var in cls.required_class_vars: + if not hasattr(cls, var): + warnings.warn( + f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", + stacklevel=2, + ) + + +class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ... + + +class Component( + ComponentFromConfig[ConfigT], + ComponentSchemaType[ConfigT], + Generic[ConfigT], +): + """To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables: + + - :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component. + - :py:attr:`component_type` - What is the logical type of the component. + + Example: + + .. code-block:: python + + from __future__ import annotations + + from pydantic import BaseModel + from agentdhal_core import Component + + + class Config(BaseModel): + value: str + + + class MyComponent(Component[Config]): + component_type = "custom" + component_config_schema = Config + + def __init__(self, value: str): + self.value = value + + def _to_config(self) -> Config: + return Config(value=self.value) + + @classmethod + def _from_config(cls, config: Config) -> MyComponent: + return cls(value=config.value) + """ + + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + + if not is_component_class(cls): + warnings.warn( + f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.", + stacklevel=2, + ) + + +# Should never be used directly, only for type checking +class _ConcreteComponent( + ComponentFromConfig[ConfigT], + ComponentSchemaType[ConfigT], + ComponentToConfig[ConfigT], + ComponentLoader, + Generic[ConfigT], +): ... + + +def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]: + return ( + isinstance(cls, ComponentFromConfig) + and isinstance(cls, ComponentToConfig) + and isinstance(cls, ComponentSchemaType) + and isinstance(cls, ComponentLoader) + ) + + +def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]: + return ( + issubclass(cls, ComponentFromConfig) + and issubclass(cls, ComponentToConfig) + and issubclass(cls, ComponentSchemaType) + and issubclass(cls, ComponentLoader) + ) diff --git a/agent_dhal/agentdhal_core/_constants.py b/agent_dhal/agentdhal_core/_constants.py new file mode 100644 index 0000000..be9bdf0 --- /dev/null +++ b/agent_dhal/agentdhal_core/_constants.py @@ -0,0 +1,9 @@ +ROOT_LOGGER_NAME = "agentdhal_core" +"""str: Logger name used for root logger""" + +EVENT_LOGGER_NAME = "agentdhal_core.events" +"""str: Logger name used for structured event logging""" + + +TRACE_LOGGER_NAME = "agentdhal_core.trace" +"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon.""" diff --git a/agent_dhal/agentdhal_core/_default_subscription.py b/agent_dhal/agentdhal_core/_default_subscription.py new file mode 100644 index 0000000..4c5251c --- /dev/null +++ b/agent_dhal/agentdhal_core/_default_subscription.py @@ -0,0 +1,53 @@ +from typing import Callable, Type, TypeVar, overload + +from ._agent_type import AgentType +from ._base_agent import BaseAgent, subscription_factory +from ._subscription_context import SubscriptionInstantiationContext +from ._type_subscription import TypeSubscription +from .exceptions import CantHandleException + + +class DefaultSubscription(TypeSubscription): + """The default subscription is designed to be a sensible default for applications that only need global scope for agents. + + This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context. + + Args: + topic_type (str, optional): The topic type to subscribe to. Defaults to "default". + agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context. + """ + + def __init__(self, topic_type: str = "default", agent_type: str | AgentType | None = None): + if agent_type is None: + try: + agent_type = SubscriptionInstantiationContext.agent_type().type + except RuntimeError as e: + raise CantHandleException( + "If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register" + ) from e + + super().__init__(topic_type, agent_type) + + +BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent") + + +@overload +def default_subscription() -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: ... + + +@overload +def default_subscription(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: ... + + +def default_subscription( + cls: Type[BaseAgentType] | None = None, +) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]] | Type[BaseAgentType]: + if cls is None: + return subscription_factory(lambda: [DefaultSubscription()]) + else: + return subscription_factory(lambda: [DefaultSubscription()])(cls) + + +def type_subscription(topic_type: str) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: + return subscription_factory(lambda: [DefaultSubscription(topic_type=topic_type)]) diff --git a/agent_dhal/agentdhal_core/_default_topic.py b/agent_dhal/agentdhal_core/_default_topic.py new file mode 100644 index 0000000..b5dde0a --- /dev/null +++ b/agent_dhal/agentdhal_core/_default_topic.py @@ -0,0 +1,23 @@ +from ._message_handler_context import MessageHandlerContext +from ._topic import TopicId + + +class DefaultTopicId(TopicId): + """DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId. + + If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default". + + Args: + type (str, optional): Topic type to publish message to. Defaults to "default". + source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None. + """ + + def __init__(self, type: str = "default", source: str | None = None) -> None: + if source is None: + try: + source = MessageHandlerContext.agent_id().key + # If we aren't in the context of a message handler, we use the default source + except RuntimeError: + source = "default" + + super().__init__(type, source) diff --git a/agent_dhal/agentdhal_core/_function_utils.py b/agent_dhal/agentdhal_core/_function_utils.py new file mode 100644 index 0000000..8910278 --- /dev/null +++ b/agent_dhal/agentdhal_core/_function_utils.py @@ -0,0 +1,324 @@ +# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py +# Credit to original authors + +import inspect +import typing +from functools import partial +from logging import getLogger +from typing import ( + Annotated, + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, +) + +from pydantic import BaseModel, Field, TypeAdapter, create_model # type: ignore +from pydantic_core import PydanticUndefined +from typing_extensions import Literal + +logger = getLogger(__name__) + +T = TypeVar("T") + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get the signature of a function with type annotations. + + Args: + call: The function to get the signature for + + Returns: + The signature of the function with type annotations + """ + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + func_call = call.func if isinstance(call, partial) else call + type_hints = typing.get_type_hints(func_call, globalns, include_extras=True) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=type_hints[param.name], + ) + for param in signature.parameters.values() + ] + return_annotation = type_hints.get("return", inspect.Signature.empty) + typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation) + return typed_signature + + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + """Get the return annotation of a function. + + Args: + call: The function to get the return annotation for + + Returns: + The return annotation of the function + """ + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + type_hints = typing.get_type_hints(call, globalns, include_extras=True) + return type_hints.get("return", inspect.Signature.empty) + + +def get_param_annotations( + typed_signature: inspect.Signature, +) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: + """Get the type annotations of the parameters of a function + + Args: + typed_signature: The signature of the function with type annotations + + Returns: + A dictionary of the type annotations of the parameters of the function + """ + return { + k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty + } + + +class Parameters(BaseModel): + """Parameters of a function as defined by the OpenAI API""" + + type: Literal["object"] = "object" + properties: Dict[str, Dict[str, Any]] + required: List[str] + + +class Function(BaseModel): + """A function as defined by the OpenAI API""" + + description: Annotated[str, Field(description="Description of the function")] + name: Annotated[str, Field(description="Name of the function")] + parameters: Annotated[Parameters, Field(description="Parameters of the function")] + + +class ToolFunction(BaseModel): + """A function under tool as defined by the OpenAI API.""" + + type: Literal["function"] = "function" + function: Annotated[Function, Field(description="Function under tool")] + + +def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str: + # handles Annotated + if hasattr(v, "__metadata__"): + retval = v.__metadata__[0] + if isinstance(retval, str): + return retval + else: + raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.") + else: + return k + + +def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]: + """Get a JSON schema for a parameter as defined by the OpenAI API + + Args: + k: The name of the parameter + v: The type of the parameter + default_values: The default values of the parameters of the function + + Returns: + A Pydanitc model for the parameter + """ + + schema = TypeAdapter(v).json_schema() + if k in default_values: + dv = default_values[k] + schema["default"] = dv + + schema["description"] = type2description(k, v) + + return schema + + +def get_required_params(typed_signature: inspect.Signature) -> List[str]: + """Get the required parameters of a function + + Args: + typed_signature: The signature of the function as returned by inspect.signature + + Returns: + A list of the required parameters of the function + """ + return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty] + + +def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]: + """Get default values of parameters of a function + + Args: + typed_signature: The signature of the function as returned by inspect.signature + + Returns: + A dictionary of the default values of the parameters of the function + """ + return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty} + + +def get_parameters( + required: List[str], + param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]], + default_values: Dict[str, Any], +) -> Parameters: + """Get the parameters of a function as defined by the OpenAI API + + Args: + required: The required parameters of the function + param_annotations: A dictionary of the type annotations of the parameters of the function + default_values: The default values of the parameters of the function + + Returns: + A Pydantic model for the parameters of the function + """ + return Parameters( + properties={ + k: get_parameter_json_schema(k, v, default_values) + for k, v in param_annotations.items() + if v is not inspect.Signature.empty + }, + required=required, + ) + + +def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]: + """Get the missing annotations of a function + + Ignores the parameters with default values as they are not required to be annotated, but logs a warning. + Args: + typed_signature: The signature of the function with type annotations + required: The required parameters of the function + + Returns: + A set of the missing annotations of the function + """ + all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty} + missing = all_missing.intersection(set(required)) + unannotated_with_default = all_missing.difference(missing) + return missing, unannotated_with_default + + +def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]: + """Get a JSON schema for a function as defined by the OpenAI API + + Args: + f: The function to get the JSON schema for + name: The name of the function + description: The description of the function + + Returns: + A JSON schema for the function + + Raises: + TypeError: If the function is not annotated + + Examples: + + .. code-block:: python + + def f( + a: Annotated[str, "Parameter a"], + b: int = 2, + c: Annotated[float, "Parameter c"] = 0.1, + ) -> None: + pass + + + get_function_schema(f, description="function f") + + # {'type': 'function', + # 'function': {'description': 'function f', + # 'name': 'f', + # 'parameters': {'type': 'object', + # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, + # 'b': {'type': 'int', 'description': 'b'}, + # 'c': {'type': 'float', 'description': 'Parameter c'}}, + # 'required': ['a']}}} + + """ + typed_signature = get_typed_signature(f) + required = get_required_params(typed_signature) + default_values = get_default_values(typed_signature) + param_annotations = get_param_annotations(typed_signature) + return_annotation = get_typed_return_annotation(f) + missing, unannotated_with_default = get_missing_annotations(typed_signature, required) + + if return_annotation is None: + logger.warning( + f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is " + + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." + ) + + if unannotated_with_default != set(): + unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)] + logger.warning( + f"The following parameters of the function '{f.__name__}' with default values are not annotated: " + + f"{', '.join(unannotated_with_default_s)}." + ) + + if missing != set(): + missing_s = [f"'{k}'" for k in sorted(missing)] + raise TypeError( + f"All parameters of the function '{f.__name__}' without default values must be annotated. " + + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" + ) + + fname = name if name else f.__name__ + + parameters = get_parameters(required, param_annotations, default_values=default_values) + + function = ToolFunction( + function=Function( + description=description, + name=fname, + parameters=parameters, + ) + ) + + return function.model_dump() + + +def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]: + """Normalize typing.Annotated types to the inner type.""" + if get_origin(type_hint) is Annotated: + # Extract the inner type from Annotated + return get_args(type_hint)[0] # type: ignore + return type_hint + + +def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]: + fields: Dict[str, tuple[Type[Any], Any]] = {} + for param_name, param in sig.parameters.items(): + # This is handled externally + if param_name == "cancellation_token": + continue + + if param.annotation is inspect.Parameter.empty: + raise ValueError("No annotation") + + type = normalize_annotated_type(param.annotation) + description = type2description(param_name, param.annotation) + default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined + + fields[param_name] = (type, Field(default=default_value, description=description)) + + return cast(BaseModel, create_model(name, **fields)) # type: ignore diff --git a/agent_dhal/agentdhal_core/_image.py b/agent_dhal/agentdhal_core/_image.py new file mode 100644 index 0000000..73d1b4b --- /dev/null +++ b/agent_dhal/agentdhal_core/_image.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import base64 +import re +from io import BytesIO +from pathlib import Path +from typing import Any, Dict, cast + +from PIL import Image as PILImage +from pydantic import GetCoreSchemaHandler, ValidationInfo +from pydantic_core import core_schema +from typing_extensions import Literal + + +class Image: + """Represents an image. + + + Example: + + Loading an image from a URL: + + .. code-block:: python + + from agentdhal_core import Image + from PIL import Image as PILImage + import aiohttp + import asyncio + + + async def from_url(url: str) -> Image: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + content = await response.read() + return Image.from_pil(PILImage.open(content)) + + + image = asyncio.run(from_url("https://example.com/image")) + + """ + + def __init__(self, image: PILImage.Image): + self.image: PILImage.Image = image.convert("RGB") + + @classmethod + def from_pil(cls, pil_image: PILImage.Image) -> Image: + return cls(pil_image) + + @classmethod + def from_uri(cls, uri: str) -> Image: + if not re.match(r"data:image/(?:png|jpeg);base64,", uri): + raise ValueError("Invalid URI format. It should be a base64 encoded image URI.") + + # A URI. Remove the prefix and decode the base64 string. + base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri) + return cls.from_base64(base64_data) + + @classmethod + def from_base64(cls, base64_str: str) -> Image: + return cls(PILImage.open(BytesIO(base64.b64decode(base64_str)))) + + def to_base64(self) -> str: + buffered = BytesIO() + self.image.save(buffered, format="PNG") + content = buffered.getvalue() + return base64.b64encode(content).decode("utf-8") + + @classmethod + def from_file(cls, file_path: Path) -> Image: + return cls(PILImage.open(file_path)) + + def _repr_html_(self) -> str: + # Show the image in Jupyter notebook + return f'' + + @property + def data_uri(self) -> str: + return _convert_base64_to_data_uri(self.to_base64()) + + # Returns openai.types.chat.ChatCompletionContentPartImageParam, which is a TypedDict + # We don't use the explicit type annotation so that we can avoid a dependency on the OpenAI Python SDK in this package. + def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> Dict[str, Any]: + return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}} + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + # Custom validation + def validate(value: Any, validation_info: ValidationInfo) -> Image: + if isinstance(value, dict): + base_64 = cast(str | None, value.get("data")) # type: ignore + if base_64 is None: + raise ValueError("Expected 'data' key in the dictionary") + return cls.from_base64(base_64) + elif isinstance(value, cls): + return value + else: + raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}") + + # Custom serialization + def serialize(value: Image) -> dict[str, Any]: + return {"data": value.to_base64()} + + return core_schema.with_info_after_validator_function( + validate, + core_schema.any_schema(), # Accept any type; adjust if needed + serialization=core_schema.plain_serializer_function_ser_schema(serialize), + ) + + +def _convert_base64_to_data_uri(base64_image: str) -> str: + def _get_mime_type_from_data_uri(base64_image: str) -> str: + # Decode the base64 string + image_data = base64.b64decode(base64_image) + # Check the first few bytes for known signatures + if image_data.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + elif image_data.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"): + return "image/gif" + elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP": + return "image/webp" + return "image/jpeg" # use jpeg for unknown formats, best guess. + + mime_type = _get_mime_type_from_data_uri(base64_image) + data_uri = f"data:{mime_type};base64,{base64_image}" + return data_uri diff --git a/agent_dhal/agentdhal_core/_intervention.py b/agent_dhal/agentdhal_core/_intervention.py new file mode 100644 index 0000000..2973455 --- /dev/null +++ b/agent_dhal/agentdhal_core/_intervention.py @@ -0,0 +1,83 @@ +from typing import Any, Protocol, final + +from ._agent_id import AgentId +from ._message_context import MessageContext + +__all__ = [ + "DropMessage", + "InterventionHandler", + "DefaultInterventionHandler", +] + + +@final +class DropMessage: + """Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler.""" + + ... + + +class InterventionHandler(Protocol): + """An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`agentdhal_core.base.AgentRuntime`. + + The handler is called when the message is submitted to the runtime. + + Currently the only runtime which supports this is the :class:`agentdhal_core.base.SingleThreadedAgentRuntime`. + + Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly. + + Example: + + .. code-block:: python + + from agentdhal_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime + from dataclasses import dataclass + from typing import Any + + + @dataclass + class MyMessage: + content: str + + + class MyInterventionHandler(DefaultInterventionHandler): + async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage: + if isinstance(message, MyMessage): + message.content = message.content.upper() + return message + + + runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()]) + + """ + + async def on_send( + self, message: Any, *, message_context: MessageContext, recipient: AgentId + ) -> Any | type[DropMessage]: + """Called when a message is submitted to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.send_message`.""" + ... + + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]: + """Called when a message is published to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.publish_message`.""" + ... + + async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]: + """Called when a response is received by the AgentRuntime from an Agent's message handler returning a value.""" + ... + + +class DefaultInterventionHandler(InterventionHandler): + """Simple class that provides a default implementation for all intervention + handler methods, that simply returns the message unchanged. Allows for easy + subclassing to override only the desired methods.""" + + async def on_send( + self, message: Any, *, message_context: MessageContext, recipient: AgentId + ) -> Any | type[DropMessage]: + return message + + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]: + return message + + async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]: + return message diff --git a/agent_dhal/agentdhal_core/_message_context.py b/agent_dhal/agentdhal_core/_message_context.py new file mode 100644 index 0000000..c5c0055 --- /dev/null +++ b/agent_dhal/agentdhal_core/_message_context.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from ._agent_id import AgentId +from ._cancellation_token import CancellationToken +from ._topic import TopicId + + +@dataclass +class MessageContext: + sender: AgentId | None + topic_id: TopicId | None + is_rpc: bool + cancellation_token: CancellationToken + message_id: str diff --git a/agent_dhal/agentdhal_core/_message_handler_context.py b/agent_dhal/agentdhal_core/_message_handler_context.py new file mode 100644 index 0000000..9e5a6a9 --- /dev/null +++ b/agent_dhal/agentdhal_core/_message_handler_context.py @@ -0,0 +1,31 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, ClassVar, Generator + +from ._agent_id import AgentId + + +class MessageHandlerContext: + def __init__(self) -> None: + raise RuntimeError( + "MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling." + ) + + _MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT") + + @classmethod + @contextmanager + def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]: + """:meta private:""" + token = MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.set(ctx) + try: + yield + finally: + MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.reset(token) + + @classmethod + def agent_id(cls) -> AgentId: + try: + return cls._MESSAGE_HANDLER_CONTEXT.get() + except LookupError as e: + raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e diff --git a/agent_dhal/agentdhal_core/_queue.py b/agent_dhal/agentdhal_core/_queue.py new file mode 100644 index 0000000..699921a --- /dev/null +++ b/agent_dhal/agentdhal_core/_queue.py @@ -0,0 +1,264 @@ +# Copy of Asyncio queue: https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py +# So that shutdown can be used in <3.13 +# Modified to work outside of the asyncio package + +import asyncio +import collections +import threading +from typing import Generic, TypeVar + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> asyncio.AbstractEventLoop: + loop = asyncio.get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class QueueShutDown(Exception): + """Raised when putting on to or getting from a shut-down Queue.""" + + pass + + +T = TypeVar("T") + + +class Queue(_LoopBoundMixin, Generic[T]): + def __init__(self, maxsize: int = 0): + self._maxsize = maxsize + self._getters = collections.deque[asyncio.Future[None]]() + self._putters = collections.deque[asyncio.Future[None]]() + self._unfinished_tasks = 0 + self._finished = asyncio.Event() + self._finished.set() + self._queue = collections.deque[T]() + self._is_shutdown = False + + # These three are overridable in subclasses. + + def _get(self) -> T: + return self._queue.popleft() + + def _put(self, item: T) -> None: + self._queue.append(item) + + # End of the overridable methods. + + def _wakeup_next(self, waiters: collections.deque[asyncio.Future[None]]) -> None: + # Wake up the next waiter (if any) that isn't cancelled. + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break + + def __repr__(self) -> str: + return f"<{type(self).__name__} at {id(self):#x} {self._format()}>" + + def __str__(self) -> str: + return f"<{type(self).__name__} {self._format()}>" + + def _format(self) -> str: + result = f"maxsize={self._maxsize!r}" + if getattr(self, "_queue", None): + result += f" _queue={list(self._queue)!r}" + if self._getters: + result += f" _getters[{len(self._getters)}]" + if self._putters: + result += f" _putters[{len(self._putters)}]" + if self._unfinished_tasks: + result += f" tasks={self._unfinished_tasks}" + if self._is_shutdown: + result += " shutdown" + return result + + def qsize(self) -> int: + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self) -> int: + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self) -> bool: + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self) -> bool: + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + async def put(self, item: T) -> None: + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + Raises QueueShutDown if the queue has been shut down. + """ + while self.full(): + if self._is_shutdown: + raise QueueShutDown + putter = self._get_loop().create_future() + self._putters.append(putter) + try: + await putter + except: + putter.cancel() # Just in case putter is not done yet. + try: + # Clean self._putters from canceled putters. + self._putters.remove(putter) + except ValueError: + # The putter could be removed from self._putters by a + # previous get_nowait call or a shutdown call. + pass + if not self.full() and not putter.cancelled(): + # We were woken up by get_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._putters) + raise + return self.put_nowait(item) + + def put_nowait(self, item: T) -> None: + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + + Raises QueueShutDown if the queue has been shut down. + """ + if self._is_shutdown: + raise QueueShutDown + if self.full(): + raise asyncio.QueueFull + self._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + self._wakeup_next(self._getters) + + async def get(self) -> T: + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + + Raises QueueShutDown if the queue has been shut down and is empty, or + if the queue has been shut down immediately. + """ + while self.empty(): + if self._is_shutdown and self.empty(): + raise QueueShutDown + getter = self._get_loop().create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. + self._getters.remove(getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call, or a shutdown call. + pass + if not self.empty() and not getter.cancelled(): + # We were woken up by put_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._getters) + raise + return self.get_nowait() + + def get_nowait(self) -> T: + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + + Raises QueueShutDown if the queue has been shut down and is empty, or + if the queue has been shut down immediately. + """ + if self.empty(): + if self._is_shutdown: + raise QueueShutDown + raise asyncio.QueueEmpty + item = self._get() + self._wakeup_next(self._putters) + return item + + def task_done(self) -> None: + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + shutdown(immediate=True) calls task_done() for each remaining item in + the queue. + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self) -> None: + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + await self._finished.wait() + + def shutdown(self, immediate: bool = False) -> None: + """Shut-down the queue, making queue gets and puts raise QueueShutDown. + + By default, gets will only raise once the queue is empty. Set + 'immediate' to True to make gets raise immediately instead. + + All blocked callers of put() and get() will be unblocked. If + 'immediate', a task is marked as done for each item remaining in + the queue, which may unblock callers of join(). + """ + self._is_shutdown = True + if immediate: + while not self.empty(): + self._get() + if self._unfinished_tasks > 0: + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + # All getters need to re-check queue-empty to raise ShutDown + while self._getters: + getter = self._getters.popleft() + if not getter.done(): + getter.set_result(None) + while self._putters: + putter = self._putters.popleft() + if not putter.done(): + putter.set_result(None) diff --git a/agent_dhal/agentdhal_core/_routed_agent.py b/agent_dhal/agentdhal_core/_routed_agent.py new file mode 100644 index 0000000..29c8df6 --- /dev/null +++ b/agent_dhal/agentdhal_core/_routed_agent.py @@ -0,0 +1,518 @@ +import logging +from functools import wraps +from typing import ( + Any, + Callable, + Coroutine, + DefaultDict, + List, + Literal, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + cast, + get_type_hints, + overload, + runtime_checkable, +) + +from ._base_agent import BaseAgent +from ._message_context import MessageContext +from ._serialization import MessageSerializer, try_get_known_serializers_for_type +from ._type_helpers import AnyType, get_types +from .exceptions import CantHandleException + +logger = logging.getLogger("agentdhal_core") + +AgentT = TypeVar("AgentT") +ReceivesT = TypeVar("ReceivesT") +ProducesT = TypeVar("ProducesT", covariant=True) + +# TODO: Generic typevar bound binding U to agent type +# Can't do because python doesnt support it + + +# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here. +# Revisit this later to see if we can remove the ignore. +@runtime_checkable +class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore + target_types: Sequence[type] + produces_types: Sequence[type] + is_message_handler: Literal[True] + router: Callable[[ReceivesT, MessageContext], bool] + + # agent_instance binds to self in the method + @staticmethod + async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ... + + +# NOTE: this works on concrete types and not inheritance +# TODO: Use a protocol for the outer function to check checked arg names + + +@overload +def message_handler( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], +) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ... + + +@overload +def message_handler( + func: None = None, + *, + match: None = ..., + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], +]: ... + + +@overload +def message_handler( + func: None = None, + *, + match: Callable[[ReceivesT, MessageContext], bool], + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], +]: ... + + +def message_handler( + func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None, + *, + strict: bool = True, + match: None | Callable[[ReceivesT, MessageContext], bool] = None, +) -> ( + Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], + ] + | MessageHandler[AgentT, ReceivesT, ProducesT] +): + """Decorator for generic message handlers. + + Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle both event and RPC messages. + These methods must have a specific signature that needs to be followed for it to be valid: + + - The method must be an `async` method. + - The method must be decorated with the `@message_handler` decorator. + - The method must have exactly 3 arguments: + 1. `self` + 2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle. + 3. `ctx`: A :class:`agentdhal_core.MessageContext` object. + - The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything. + + Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types. + + Args: + func: The function to be decorated. + strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead. + match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called. + """ + + def decorator( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], + ) -> MessageHandler[AgentT, ReceivesT, ProducesT]: + type_hints = get_type_hints(func) + if "message" not in type_hints: + raise AssertionError("message parameter not found in function signature") + + if "return" not in type_hints: + raise AssertionError("return parameter not found in function signature") + + # Get the type of the message parameter + target_types = get_types(type_hints["message"]) + if target_types is None: + raise AssertionError("Message type not found") + + # print(type_hints) + return_types = get_types(type_hints["return"]) + + if return_types is None: + raise AssertionError("Return type not found") + + # Convert target_types to list and stash + + @wraps(func) + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + if type(message) not in target_types: + if strict: + raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") + else: + logger.warning(f"Message type {type(message)} not in target types {target_types}") + + return_value = await func(self, message, ctx) + + if AnyType not in return_types and type(return_value) not in return_types: + if strict: + raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") + else: + logger.warning(f"Return type {type(return_value)} not in return types {return_types}") + + return return_value + + wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) + wrapper_handler.target_types = list(target_types) + wrapper_handler.produces_types = list(return_types) + wrapper_handler.is_message_handler = True + wrapper_handler.router = match or (lambda _message, _ctx: True) + + return wrapper_handler + + if func is None and not callable(func): + return decorator + elif callable(func): + return decorator(func) + else: + raise ValueError("Invalid arguments") + + +@overload +def event( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]], +) -> MessageHandler[AgentT, ReceivesT, None]: ... + + +@overload +def event( + func: None = None, + *, + match: None = ..., + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], + MessageHandler[AgentT, ReceivesT, None], +]: ... + + +@overload +def event( + func: None = None, + *, + match: Callable[[ReceivesT, MessageContext], bool], + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], + MessageHandler[AgentT, ReceivesT, None], +]: ... + + +def event( + func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None, + *, + strict: bool = True, + match: None | Callable[[ReceivesT, MessageContext], bool] = None, +) -> ( + Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]], + MessageHandler[AgentT, ReceivesT, None], + ] + | MessageHandler[AgentT, ReceivesT, None] +): + """Decorator for event message handlers. + + Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle event messages. + These methods must have a specific signature that needs to be followed for it to be valid: + + - The method must be an `async` method. + - The method must be decorated with the `@message_handler` decorator. + - The method must have exactly 3 arguments: + 1. `self` + 2. `message`: The event message to be handled, this must be type-hinted with the message type that it is intended to handle. + 3. `ctx`: A :class:`agentdhal_core.MessageContext` object. + - The method must return `None`. + + Handlers can handle more than one message type by accepting a Union of the message types. + + Args: + func: The function to be decorated. + strict: If `True`, the handler will raise an exception if the message type is not in the target types. If `False`, it will log a warning instead. + match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called. + """ + + def decorator( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]], + ) -> MessageHandler[AgentT, ReceivesT, None]: + type_hints = get_type_hints(func) + if "message" not in type_hints: + raise AssertionError("message parameter not found in function signature") + + if "return" not in type_hints: + raise AssertionError("return parameter not found in function signature") + + # Get the type of the message parameter + target_types = get_types(type_hints["message"]) + if target_types is None: + raise AssertionError("Message type not found. Please provide a type hint for the message parameter.") + + return_types = get_types(type_hints["return"]) + + if return_types is None: + raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.") + + # Convert target_types to list and stash + + @wraps(func) + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: + if type(message) not in target_types: + if strict: + raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") + else: + logger.warning(f"Message type {type(message)} not in target types {target_types}") + + return_value = await func(self, message, ctx) # type: ignore + + if return_value is not None: + if strict: + raise ValueError(f"Return type {type(return_value)} is not None.") + else: + logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.") + + return None + + wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper) + wrapper_handler.target_types = list(target_types) + wrapper_handler.produces_types = list(return_types) + wrapper_handler.is_message_handler = True + # Wrap the match function with a check on the is_rpc flag. + wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True) + + return wrapper_handler + + if func is None and not callable(func): + return decorator + elif callable(func): + return decorator(func) + else: + raise ValueError("Invalid arguments") + + +@overload +def rpc( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], +) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ... + + +@overload +def rpc( + func: None = None, + *, + match: None = ..., + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], +]: ... + + +@overload +def rpc( + func: None = None, + *, + match: Callable[[ReceivesT, MessageContext], bool], + strict: bool = ..., +) -> Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], +]: ... + + +def rpc( + func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None, + *, + strict: bool = True, + match: None | Callable[[ReceivesT, MessageContext], bool] = None, +) -> ( + Callable[ + [Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], + MessageHandler[AgentT, ReceivesT, ProducesT], + ] + | MessageHandler[AgentT, ReceivesT, ProducesT] +): + """Decorator for RPC message handlers. + + Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle RPC messages. + These methods must have a specific signature that needs to be followed for it to be valid: + + - The method must be an `async` method. + - The method must be decorated with the `@message_handler` decorator. + - The method must have exactly 3 arguments: + 1. `self` + 2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle. + 3. `ctx`: A :class:`agentdhal_core.MessageContext` object. + - The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything. + + Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types. + + Args: + func: The function to be decorated. + strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead. + match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called. + """ + + def decorator( + func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], + ) -> MessageHandler[AgentT, ReceivesT, ProducesT]: + type_hints = get_type_hints(func) + if "message" not in type_hints: + raise AssertionError("message parameter not found in function signature") + + if "return" not in type_hints: + raise AssertionError("return parameter not found in function signature") + + # Get the type of the message parameter + target_types = get_types(type_hints["message"]) + if target_types is None: + raise AssertionError("Message type not found") + + # print(type_hints) + return_types = get_types(type_hints["return"]) + + if return_types is None: + raise AssertionError("Return type not found") + + # Convert target_types to list and stash + + @wraps(func) + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + if type(message) not in target_types: + if strict: + raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") + else: + logger.warning(f"Message type {type(message)} not in target types {target_types}") + + return_value = await func(self, message, ctx) + + if AnyType not in return_types and type(return_value) not in return_types: + if strict: + raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") + else: + logger.warning(f"Return type {type(return_value)} not in return types {return_types}") + + return return_value + + wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) + wrapper_handler.target_types = list(target_types) + wrapper_handler.produces_types = list(return_types) + wrapper_handler.is_message_handler = True + wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True) + + return wrapper_handler + + if func is None and not callable(func): + return decorator + elif callable(func): + return decorator(func) + else: + raise ValueError("Invalid arguments") + + +class RoutedAgent(BaseAgent): + """A base class for agents that route messages to handlers based on the type of the message + and optional matching functions. + + To create a routed agent, subclass this class and add message handlers as methods decorated with + either :func:`event` or :func:`rpc` decorator. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + from agentdhal_core import MessageContext + from agentdhal_core import RoutedAgent, event, rpc + + + @dataclass + class Message: + pass + + + @dataclass + class MessageWithContent: + content: str + + + @dataclass + class Response: + pass + + + class MyAgent(RoutedAgent): + def __init__(self): + super().__init__("MyAgent") + + @event + async def handle_event_message(self, message: Message, ctx: MessageContext) -> None: + assert ctx.topic_id is not None + await self.publish_message(MessageWithContent("event handled"), ctx.topic_id) + + @rpc(match=lambda message, ctx: message.content == "special") # type: ignore + async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response: + return Response() + """ + + def __init__(self, description: str) -> None: + # Self is already bound to the handlers + self._handlers: DefaultDict[ + Type[Any], + List[MessageHandler[RoutedAgent, Any, Any]], + ] = DefaultDict(list) + + handlers = self._discover_handlers() + for message_handler in handlers: + for target_type in message_handler.target_types: + self._handlers[target_type].append(message_handler) + + super().__init__(description) + + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: + """Handle a message by routing it to the appropriate message handler. + Do not override this method in subclasses. Instead, add message handlers as methods decorated with + either the :func:`event` or :func:`rpc` decorator.""" + key_type: Type[Any] = type(message) # type: ignore + handlers = self._handlers.get(key_type) # type: ignore + if handlers is not None: + # Iterate over all handlers for this matching message type. + # Call the first handler whose router returns True and then return the result. + for h in handlers: + if h.router(message, ctx): + return await h(self, message, ctx) + return await self.on_unhandled_message(message, ctx) # type: ignore + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + """Called when a message is received that does not have a matching message handler. + The default implementation logs an info message.""" + logger.info(f"Unhandled message: {message}") + + @classmethod + def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]: + handlers: List[MessageHandler[Any, Any, Any]] = [] + for attr in dir(cls): + if callable(getattr(cls, attr, None)): + # Since we are getting it from the class, self is not bound + handler = getattr(cls, attr) + if hasattr(handler, "is_message_handler"): + handlers.append(cast(MessageHandler[Any, Any, Any], handler)) + return handlers + + @classmethod + def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]: + # TODO handle deduplication + handlers = cls._discover_handlers() + types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = [] + types.extend(cls.internal_extra_handles_types) + for handler in handlers: + for t in handler.target_types: + # TODO: support different serializers + serializers = try_get_known_serializers_for_type(t) + if len(serializers) == 0: + raise ValueError(f"No serializers found for type {t}.") + + types.append((t, try_get_known_serializers_for_type(t))) + return types diff --git a/agent_dhal/agentdhal_core/_runtime_impl_helpers.py b/agent_dhal/agentdhal_core/_runtime_impl_helpers.py new file mode 100644 index 0000000..31bfc04 --- /dev/null +++ b/agent_dhal/agentdhal_core/_runtime_impl_helpers.py @@ -0,0 +1,78 @@ +from collections import defaultdict +from typing import Awaitable, Callable, DefaultDict, List, Sequence, Set + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_type import AgentType +from ._subscription import Subscription +from ._topic import TopicId + + +async def get_impl( + *, + id_or_type: AgentId | AgentType | str, + key: str, + lazy: bool, + instance_getter: Callable[[AgentId], Awaitable[Agent]], +) -> AgentId: + if isinstance(id_or_type, AgentId): + if not lazy: + await instance_getter(id_or_type) + + return id_or_type + + type_str = id_or_type if isinstance(id_or_type, str) else id_or_type.type + id = AgentId(type_str, key) + if not lazy: + await instance_getter(id) + + return id + + +class SubscriptionManager: + def __init__(self) -> None: + self._subscriptions: List[Subscription] = [] + self._seen_topics: Set[TopicId] = set() + self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list) + + @property + def subscriptions(self) -> Sequence[Subscription]: + return self._subscriptions + + async def add_subscription(self, subscription: Subscription) -> None: + # Check if the subscription already exists + if any(sub == subscription for sub in self._subscriptions): + raise ValueError("Subscription already exists") + + self._subscriptions.append(subscription) + self._rebuild_subscriptions(self._seen_topics) + + async def remove_subscription(self, id: str) -> None: + # Check if the subscription exists + if not any(sub.id == id for sub in self._subscriptions): + raise ValueError("Subscription does not exist") + + def is_not_sub(x: Subscription) -> bool: + return x.id != id + + self._subscriptions = list(filter(is_not_sub, self._subscriptions)) + + # Rebuild the subscriptions + self._rebuild_subscriptions(self._seen_topics) + + async def get_subscribed_recipients(self, topic: TopicId) -> List[AgentId]: + if topic not in self._seen_topics: + self._build_for_new_topic(topic) + return self._subscribed_recipients[topic] + + # TODO: optimize this... + def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None: + self._subscribed_recipients.clear() + for topic in topics: + self._build_for_new_topic(topic) + + def _build_for_new_topic(self, topic: TopicId) -> None: + self._seen_topics.add(topic) + for subscription in self._subscriptions: + if subscription.is_match(topic): + self._subscribed_recipients[topic].append(subscription.map_to_agent(topic)) diff --git a/agent_dhal/agentdhal_core/_serialization.py b/agent_dhal/agentdhal_core/_serialization.py new file mode 100644 index 0000000..5ac5c50 --- /dev/null +++ b/agent_dhal/agentdhal_core/_serialization.py @@ -0,0 +1,258 @@ +import json +from dataclasses import asdict, dataclass, fields +from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable + +from google.protobuf import any_pb2 +from google.protobuf.message import Message +from pydantic import BaseModel + +from ._type_helpers import is_union + +T = TypeVar("T") + + +class MessageSerializer(Protocol[T]): + @property + def data_content_type(self) -> str: ... + + @property + def type_name(self) -> str: ... + + def deserialize(self, payload: bytes) -> T: ... + + def serialize(self, message: T) -> bytes: ... + + +@runtime_checkable +class IsDataclass(Protocol): + # as already noted in comments, checking for this attribute is currently + # the most reliable way to ascertain that something is a dataclass + __dataclass_fields__: ClassVar[Dict[str, Any]] + + +def is_dataclass(cls: type[Any]) -> bool: + return hasattr(cls, "__dataclass_fields__") + + +def has_nested_dataclass(cls: type[IsDataclass]) -> bool: + # iterate fields and check if any of them are dataclasses + return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values()) + + +def contains_a_union(cls: type[IsDataclass]) -> bool: + return any(is_union(f.type) for f in cls.__dataclass_fields__.values()) + + +def has_nested_base_model(cls: type[IsDataclass]) -> bool: + for f in fields(cls): + field_type = f.type + # Resolve forward references and other annotations + origin = get_origin(field_type) + args = get_args(field_type) + + # If the field type is directly a subclass of BaseModel + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + return True + + # If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc. + if origin is not None and args: + for arg in args: + # Recursively check the argument types + if isinstance(arg, type) and issubclass(arg, BaseModel): + return True + elif get_origin(arg) is not None: + # Handle nested generics like List[List[BaseModel]] + if has_nested_base_model_in_type(arg): + return True + # Handle Union types + elif args: + for arg in args: + if isinstance(arg, type) and issubclass(arg, BaseModel): + return True + elif get_origin(arg) is not None: + if has_nested_base_model_in_type(arg): + return True + return False + + +def has_nested_base_model_in_type(tp: Any) -> bool: + """Helper function to check if a type or its arguments is a BaseModel subclass.""" + origin = get_origin(tp) + args = get_args(tp) + + if isinstance(tp, type) and issubclass(tp, BaseModel): + return True + if origin is not None and args: + for arg in args: + if has_nested_base_model_in_type(arg): + return True + return False + + +DataclassT = TypeVar("DataclassT", bound=IsDataclass) + +JSON_DATA_CONTENT_TYPE = "application/json" +"""JSON data content type""" + +# TODO: what's the correct content type? There seems to be some disagreement over what it should be +PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf" +"""Protobuf data content type""" + + +class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]): + def __init__(self, cls: type[DataclassT]) -> None: + if contains_a_union(cls): + raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model") + + if has_nested_dataclass(cls) or has_nested_base_model(cls): + raise ValueError( + "Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model" + ) + + self.cls = cls + + @property + def data_content_type(self) -> str: + return JSON_DATA_CONTENT_TYPE + + @property + def type_name(self) -> str: + return _type_name(self.cls) + + def deserialize(self, payload: bytes) -> DataclassT: + message_str = payload.decode("utf-8") + return self.cls(**json.loads(message_str)) + + def serialize(self, message: DataclassT) -> bytes: + return json.dumps(asdict(message)).encode("utf-8") + + +PydanticT = TypeVar("PydanticT", bound=BaseModel) + + +class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]): + def __init__(self, cls: type[PydanticT]) -> None: + self.cls = cls + + @property + def data_content_type(self) -> str: + return JSON_DATA_CONTENT_TYPE + + @property + def type_name(self) -> str: + return _type_name(self.cls) + + def deserialize(self, payload: bytes) -> PydanticT: + message_str = payload.decode("utf-8") + return self.cls.model_validate_json(message_str) + + def serialize(self, message: PydanticT) -> bytes: + return message.model_dump_json().encode("utf-8") + + +ProtobufT = TypeVar("ProtobufT", bound=Message) + + +# This class serializes to and from a google.protobuf.Any message that has been serialized to a string +class ProtobufMessageSerializer(MessageSerializer[ProtobufT]): + def __init__(self, cls: type[ProtobufT]) -> None: + self.cls = cls + + @property + def data_content_type(self) -> str: + return PROTOBUF_DATA_CONTENT_TYPE + + @property + def type_name(self) -> str: + return _type_name(self.cls) + + def deserialize(self, payload: bytes) -> ProtobufT: + # Parse payload into a proto any + any_proto = any_pb2.Any() + any_proto.ParseFromString(payload) + + destination_message = self.cls() + + if not any_proto.Unpack(destination_message): # type: ignore + raise ValueError(f"Failed to unpack payload into {self.cls}") + + return destination_message + + def serialize(self, message: ProtobufT) -> bytes: + any_proto = any_pb2.Any() + any_proto.Pack(message) # type: ignore + return any_proto.SerializeToString() + + +@dataclass +class UnknownPayload: + type_name: str + data_content_type: str + payload: bytes + + +def _type_name(cls: type[Any] | Any) -> str: + # If cls is a protobuf, then we need to determine the descriptor + if isinstance(cls, type): + if issubclass(cls, Message): + return cast(str, cls.DESCRIPTOR.full_name) + elif isinstance(cls, Message): + return cast(str, cls.DESCRIPTOR.full_name) + + if isinstance(cls, type): + return cls.__name__ + else: + return cast(str, cls.__class__.__name__) + + +V = TypeVar("V") + + +def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer[Any]]: + """:meta private:""" + + serializers: List[MessageSerializer[Any]] = [] + if issubclass(cls, BaseModel): + serializers.append(PydanticJsonMessageSerializer(cls)) + elif is_dataclass(cls): + serializers.append(DataclassJsonMessageSerializer(cls)) + elif issubclass(cls, Message): + serializers.append(ProtobufMessageSerializer(cls)) + + return serializers + + +class SerializationRegistry: + """:meta private:""" + + def __init__(self) -> None: + # type_name, data_content_type -> serializer + self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {} + + def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: + if isinstance(serializer, Sequence): + for c in serializer: + self.add_serializer(c) + return + + self._serializers[(serializer.type_name, serializer.data_content_type)] = serializer + + def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any: + serializer = self._serializers.get((type_name, data_content_type)) + if serializer is None: + return UnknownPayload(type_name, data_content_type, payload) + + return serializer.deserialize(payload) + + def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes: + serializer = self._serializers.get((type_name, data_content_type)) + if serializer is None: + raise ValueError(f"Unknown type {type_name} with content type {data_content_type}") + + return serializer.serialize(message) + + def is_registered(self, type_name: str, data_content_type: str) -> bool: + return (type_name, data_content_type) in self._serializers + + def type_name(self, message: Any) -> str: + return _type_name(message) diff --git a/agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py b/agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py new file mode 100644 index 0000000..ffa042e --- /dev/null +++ b/agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py @@ -0,0 +1,1029 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +import sys +import uuid +import warnings +from asyncio import CancelledError, Future, Queue, Task +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast + +from opentelemetry.trace import TracerProvider + +from .logging import ( + AgentConstructionExceptionEvent, + DeliveryStage, + MessageDroppedEvent, + MessageEvent, + MessageHandlerExceptionEvent, + MessageKind, +) + +if sys.version_info >= (3, 13): + from asyncio import Queue, QueueShutDown +else: + from ._queue import Queue, QueueShutDown # type: ignore + + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext +from ._agent_metadata import AgentMetadata +from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType +from ._cancellation_token import CancellationToken +from ._intervention import DropMessage, InterventionHandler +from ._message_context import MessageContext +from ._message_handler_context import MessageHandlerContext +from ._runtime_impl_helpers import SubscriptionManager, get_impl +from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry +from ._subscription import Subscription +from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata +from ._topic import TopicId +from .exceptions import MessageDroppedException + +logger = logging.getLogger("agentdhal_core") +event_logger = logging.getLogger("agentdhal_core.events") + +# We use a type parameter in some functions which shadows the built-in `type` function. +# This is a workaround to avoid shadowing the built-in `type` function. +type_func_alias = type + + +@dataclass(kw_only=True) +class PublishMessageEnvelope: + """A message envelope for publishing messages to all agents that can handle + the message of the type T.""" + + message: Any + cancellation_token: CancellationToken + sender: AgentId | None + topic_id: TopicId + metadata: EnvelopeMetadata | None = None + message_id: str + + +@dataclass(kw_only=True) +class SendMessageEnvelope: + """A message envelope for sending a message to a specific agent that can handle + the message of the type T.""" + + message: Any + sender: AgentId | None + recipient: AgentId + future: Future[Any] + cancellation_token: CancellationToken + metadata: EnvelopeMetadata | None = None + message_id: str + + +@dataclass(kw_only=True) +class ResponseMessageEnvelope: + """A message envelope for sending a response to a message.""" + + message: Any + future: Future[Any] + sender: AgentId + recipient: AgentId | None + metadata: EnvelopeMetadata | None = None + + +P = ParamSpec("P") +T = TypeVar("T", bound=Agent) + + +class RunContext: + def __init__(self, runtime: SingleThreadedAgentRuntime) -> None: + self._runtime = runtime + self._run_task = asyncio.create_task(self._run()) + self._stopped = asyncio.Event() + + async def _run(self) -> None: + while True: + if self._stopped.is_set(): + return + + await self._runtime._process_next() # type: ignore + + async def stop(self) -> None: + self._stopped.set() + self._runtime._message_queue.shutdown(immediate=True) # type: ignore + await self._run_task + + async def stop_when_idle(self) -> None: + await self._runtime._message_queue.join() # type: ignore + self._stopped.set() + self._runtime._message_queue.shutdown(immediate=True) # type: ignore + await self._run_task + + async def stop_when(self, condition: Callable[[], bool], check_period: float = 1.0) -> None: + async def check_condition() -> None: + while not condition(): + await asyncio.sleep(check_period) + await self.stop() + + await asyncio.create_task(check_condition()) + + +def _warn_if_none(value: Any, handler_name: str) -> None: + """ + Utility function to check if the intervention handler returned None and issue a warning. + + Args: + value: The return value to check + handler_name: Name of the intervention handler method for the warning message + """ + if value is None: + warnings.warn( + f"Intervention handler {handler_name} returned None. This might be unintentional. " + "Consider returning the original message or DropMessage explicitly.", + RuntimeWarning, + stacklevel=2, + ) + + +class SingleThreadedAgentRuntime(AgentRuntime): + """A single-threaded agent runtime that processes all messages using a single asyncio queue. + Messages are delivered in the order they are received, and the runtime processes + each message in a separate asyncio task concurrently. + + .. note:: + + This runtime is suitable for development and standalone applications. + It is not suitable for high-throughput or high-concurrency scenarios. + + Args: + intervention_handlers (List[InterventionHandler], optional): A list of intervention + handlers that can intercept messages before they are sent or published. Defaults to None. + tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None. + Additionally, you can set environment variable `AUTOGEN_DISABLE_RUNTIME_TRACING` to `true` to disable the agent runtime telemetry if you don't have access to the runtime constructor. For example, if you are using `ComponentConfig`. + ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions in that occur in agent event handlers. Any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True. + + Examples: + + A simple example of creating a runtime, registering an agent, sending a message and stopping the runtime: + + .. code-block:: python + + import asyncio + from dataclasses import dataclass + + from agentdhal_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + @message_handler + async def handle_my_message(self, message: MyMessage, ctx: MessageContext) -> None: + print(f"Received message: {message.content}") + + + async def main() -> None: + # Create a runtime and register the agent + runtime = SingleThreadedAgentRuntime() + await MyAgent.register(runtime, "my_agent", lambda: MyAgent("My agent")) + + # Start the runtime, send a message and stop the runtime + runtime.start() + await runtime.send_message(MyMessage("Hello, world!"), recipient=AgentId("my_agent", "default")) + await runtime.stop() + + + asyncio.run(main()) + + An example of creating a runtime, registering an agent, publishing a message and stopping the runtime: + + .. code-block:: python + + import asyncio + from dataclasses import dataclass + + from agentdhal_core import ( + DefaultTopicId, + MessageContext, + RoutedAgent, + SingleThreadedAgentRuntime, + default_subscription, + message_handler, + ) + + + @dataclass + class MyMessage: + content: str + + + # The agent is subscribed to the default topic. + @default_subscription + class MyAgent(RoutedAgent): + @message_handler + async def handle_my_message(self, message: MyMessage, ctx: MessageContext) -> None: + print(f"Received message: {message.content}") + + + async def main() -> None: + # Create a runtime and register the agent + runtime = SingleThreadedAgentRuntime() + await MyAgent.register(runtime, "my_agent", lambda: MyAgent("My agent")) + + # Start the runtime. + runtime.start() + # Publish a message to the default topic that the agent is subscribed to. + await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId()) + # Wait for the message to be processed and then stop the runtime. + await runtime.stop_when_idle() + + + asyncio.run(main()) + + """ + + def __init__( + self, + *, + intervention_handlers: List[InterventionHandler] | None = None, + tracer_provider: TracerProvider | None = None, + ignore_unhandled_exceptions: bool = True, + ) -> None: + self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) + self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue() + # (namespace, type) -> List[AgentId] + self._agent_factories: Dict[ + str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] + ] = {} + self._instantiated_agents: Dict[AgentId, Agent] = {} + self._intervention_handlers = intervention_handlers + self._background_tasks: Set[Task[Any]] = set() + self._subscription_manager = SubscriptionManager() + self._run_context: RunContext | None = None + self._serialization_registry = SerializationRegistry() + self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions + self._background_exception: BaseException | None = None + self._agent_instance_types: Dict[str, Type[Agent]] = {} + + @property + def unprocessed_messages_count( + self, + ) -> int: + return self._message_queue.qsize() + + @property + def _known_agent_names(self) -> Set[str]: + return set(self._agent_factories.keys()) + + async def _create_otel_attributes( + self, + sender_agent_id: AgentId | None = None, + recipient_agent_id: AgentId | None = None, + message_context: MessageContext | None = None, + message: Any = None, + ) -> Mapping[str, str]: + """Create OpenTelemetry attributes for the given agent and message. + + Args: + sender_agent (Agent, optional): The sender agent instance. + recipient_agent (Agent, optional): The recipient agent instance. + message (Any): The message instance. + + Returns: + Attributes: A dictionary of OpenTelemetry attributes. + """ + if not sender_agent_id and not recipient_agent_id and not message: + return {} + attributes: Dict[str, str] = {} + if sender_agent_id: + sender_agent = await self._get_agent(sender_agent_id) + attributes["sender_agent_type"] = sender_agent.id.type + attributes["sender_agent_class"] = sender_agent.__class__.__name__ + if recipient_agent_id: + recipient_agent = await self._get_agent(recipient_agent_id) + attributes["recipient_agent_type"] = recipient_agent.id.type + attributes["recipient_agent_class"] = recipient_agent.__class__.__name__ + + if message_context: + serialized_message_context = { + "sender": str(message_context.sender), + "topic_id": str(message_context.topic_id), + "is_rpc": message_context.is_rpc, + "message_id": message_context.message_id, + } + attributes["message_context"] = json.dumps(serialized_message_context) + + if message: + try: + serialized_message = self._try_serialize(message) + except Exception as e: + serialized_message = str(e) + else: + serialized_message = "No Message" + attributes["message"] = serialized_message + + return attributes + + # Returns the response of the message + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: + if cancellation_token is None: + cancellation_token = CancellationToken() + + if message_id is None: + message_id = str(uuid.uuid4()) + + event_logger.info( + MessageEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.DIRECT, + delivery_stage=DeliveryStage.SEND, + ) + ) + + with self._tracer_helper.trace_block( + "create", + recipient, + parent=None, + extraAttributes={"message_type": type(message).__name__}, + ): + future = asyncio.get_event_loop().create_future() + if recipient.type not in self._known_agent_names: + future.set_exception(Exception("Recipient not found")) + return await future + + content = message.__dict__ if hasattr(message, "__dict__") else message + logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") + + await self._message_queue.put( + SendMessageEnvelope( + message=message, + recipient=recipient, + future=future, + cancellation_token=cancellation_token, + sender=sender, + metadata=get_telemetry_envelope_metadata(), + message_id=message_id, + ) + ) + + cancellation_token.link_future(future) + + return await future + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> None: + with self._tracer_helper.trace_block( + "create", + topic_id, + parent=None, + extraAttributes={"message_type": type(message).__name__}, + ): + if cancellation_token is None: + cancellation_token = CancellationToken() + content = message.__dict__ if hasattr(message, "__dict__") else message + logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}") + + if message_id is None: + message_id = str(uuid.uuid4()) + + event_logger.info( + MessageEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=topic_id, + kind=MessageKind.PUBLISH, + delivery_stage=DeliveryStage.SEND, + ) + ) + + await self._message_queue.put( + PublishMessageEnvelope( + message=message, + cancellation_token=cancellation_token, + sender=sender, + topic_id=topic_id, + metadata=get_telemetry_envelope_metadata(), + message_id=message_id, + ) + ) + + async def save_state(self) -> Mapping[str, Any]: + """Save the state of all instantiated agents. + + This method calls the :meth:`~agentdhal_core.BaseAgent.save_state` method on each agent and returns a dictionary + mapping agent IDs to their state. + + .. note:: + This method does not currently save the subscription state. We will add this in the future. + + Returns: + A dictionary mapping agent IDs to their state. + + """ + state: Dict[str, Dict[str, Any]] = {} + for agent_id in self._instantiated_agents: + state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state()) + return state + + async def load_state(self, state: Mapping[str, Any]) -> None: + """Load the state of all instantiated agents. + + This method calls the :meth:`~agentdhal_core.BaseAgent.load_state` method on each agent with the state + provided in the dictionary. The keys of the dictionary are the agent IDs, and the values are the state + dictionaries returned by the :meth:`~agentdhal_core.BaseAgent.save_state` method. + + .. note:: + + This method does not currently load the subscription state. We will add this in the future. + + """ + for agent_id_str in state: + agent_id = AgentId.from_str(agent_id_str) + if agent_id.type in self._known_agent_names: + await (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) + + async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: + with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): + recipient = message_envelope.recipient + + if recipient.type not in self._known_agent_names: + raise LookupError(f"Agent type '{recipient.type}' does not exist.") + + try: + sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown" + logger.info( + f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}" + ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=recipient, + kind=MessageKind.DIRECT, + delivery_stage=DeliveryStage.DELIVER, + ) + ) + recipient_agent = await self._get_agent(recipient) + + message_context = MessageContext( + sender=message_envelope.sender, + topic_id=None, + is_rpc=True, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + with self._tracer_helper.trace_block( + "process", + recipient_agent.id, + parent=message_envelope.metadata, + attributes=await self._create_otel_attributes( + sender_agent_id=message_envelope.sender, + recipient_agent_id=recipient, + message_context=message_context, + message=message_envelope.message, + ), + ): + with MessageHandlerContext.populate_context(recipient_agent.id): + response = await recipient_agent.on_message( + message_envelope.message, + ctx=message_context, + ) + except CancelledError as e: + if not message_envelope.future.cancelled(): + message_envelope.future.set_exception(e) + self._message_queue.task_done() + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=recipient, + exception=e, + ) + ) + return + except BaseException as e: + message_envelope.future.set_exception(e) + self._message_queue.task_done() + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=recipient, + exception=e, + ) + ) + return + + event_logger.info( + MessageEvent( + payload=self._try_serialize(response), + sender=message_envelope.recipient, + receiver=message_envelope.sender, + kind=MessageKind.RESPOND, + delivery_stage=DeliveryStage.SEND, + ) + ) + + await self._message_queue.put( + ResponseMessageEnvelope( + message=response, + future=message_envelope.future, + sender=message_envelope.recipient, + recipient=message_envelope.sender, + metadata=get_telemetry_envelope_metadata(), + ) + ) + self._message_queue.task_done() + + async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: + with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata): + try: + responses: List[Awaitable[Any]] = [] + recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id) + for agent_id in recipients: + # Avoid sending the message back to the sender + if message_envelope.sender is not None and agent_id == message_envelope.sender: + continue + + sender_agent = ( + await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None + ) + sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown" + logger.info( + f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}" + ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=None, + kind=MessageKind.PUBLISH, + delivery_stage=DeliveryStage.DELIVER, + ) + ) + message_context = MessageContext( + sender=message_envelope.sender, + topic_id=message_envelope.topic_id, + is_rpc=False, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + agent = await self._get_agent(agent_id) + + async def _on_message(agent: Agent, message_context: MessageContext) -> Any: + with self._tracer_helper.trace_block( + "process", + agent.id, + parent=message_envelope.metadata, + attributes=await self._create_otel_attributes( + sender_agent_id=message_envelope.sender, + recipient_agent_id=agent.id, + message_context=message_context, + message=message_envelope.message, + ), + ): + with MessageHandlerContext.populate_context(agent.id): + try: + return await agent.on_message( + message_envelope.message, + ctx=message_context, + ) + except BaseException as e: + logger.error(f"Error processing publish message for {agent.id}", exc_info=True) + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=agent.id, + exception=e, + ) + ) + raise e + + future = _on_message(agent, message_context) + responses.append(future) + + await asyncio.gather(*responses) + except BaseException as e: + if not self._ignore_unhandled_handler_exceptions: + self._background_exception = e + finally: + self._message_queue.task_done() + # TODO if responses are given for a publish + + async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: + with self._tracer_helper.trace_block( + "ack", + message_envelope.recipient, + parent=message_envelope.metadata, + attributes=await self._create_otel_attributes( + sender_agent_id=message_envelope.sender, + recipient_agent_id=message_envelope.recipient, + message=message_envelope.message, + ), + ): + content = ( + message_envelope.message.__dict__ + if hasattr(message_envelope.message, "__dict__") + else message_envelope.message + ) + logger.info( + f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" + ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=message_envelope.recipient, + kind=MessageKind.RESPOND, + delivery_stage=DeliveryStage.DELIVER, + ) + ) + if not message_envelope.future.cancelled(): + message_envelope.future.set_result(message_envelope.message) + self._message_queue.task_done() + + async def process_next(self) -> None: + """Process the next message in the queue. + + If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised. + """ + await self._process_next() + + async def _process_next(self) -> None: + """Process the next message in the queue.""" + + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + self._message_queue.shutdown(immediate=True) # type: ignore + raise e + + try: + message_envelope = await self._message_queue.get() + except QueueShutDown: + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + raise e from None + return + + match message_envelope: + case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): + if self._intervention_handlers is not None: + for handler in self._intervention_handlers: + with self._tracer_helper.trace_block( + "intercept", handler.__class__.__name__, parent=message_envelope.metadata + ): + try: + message_context = MessageContext( + sender=sender, + topic_id=None, + is_rpc=True, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + temp_message = await handler.on_send( + message, message_context=message_context, recipient=recipient + ) + _warn_if_none(temp_message, "on_send") + except BaseException as e: + future.set_exception(e) + return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.DIRECT, + ) + ) + future.set_exception(MessageDroppedException()) + return + + message_envelope.message = temp_message + task = asyncio.create_task(self._process_send(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case PublishMessageEnvelope( + message=message, + sender=sender, + topic_id=topic_id, + ): + if self._intervention_handlers is not None: + for handler in self._intervention_handlers: + with self._tracer_helper.trace_block( + "intercept", handler.__class__.__name__, parent=message_envelope.metadata + ): + try: + message_context = MessageContext( + sender=sender, + topic_id=topic_id, + is_rpc=False, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + temp_message = await handler.on_publish(message, message_context=message_context) + _warn_if_none(temp_message, "on_publish") + except BaseException as e: + # TODO: we should raise the intervention exception to the publisher. + logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) + return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=topic_id, + kind=MessageKind.PUBLISH, + ) + ) + return + + message_envelope.message = temp_message + + task = asyncio.create_task(self._process_publish(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): + if self._intervention_handlers is not None: + for handler in self._intervention_handlers: + try: + temp_message = await handler.on_response(message, sender=sender, recipient=recipient) + _warn_if_none(temp_message, "on_response") + except BaseException as e: + # TODO: should we raise the exception to sender of the response instead? + future.set_exception(e) + return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.RESPOND, + ) + ) + future.set_exception(MessageDroppedException()) + return + message_envelope.message = temp_message + task = asyncio.create_task(self._process_response(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + # Yield control to the message loop to allow other tasks to run + await asyncio.sleep(0) + + def start(self) -> None: + """Start the runtime message processing loop. This runs in a background task. + + Example: + + .. code-block:: python + + import asyncio + from agentdhal_core import SingleThreadedAgentRuntime + + + async def main() -> None: + runtime = SingleThreadedAgentRuntime() + runtime.start() + + # ... do other things ... + + await runtime.stop() + + + asyncio.run(main()) + + """ + if self._run_context is not None: + raise RuntimeError("Runtime is already started") + self._run_context = RunContext(self) + + async def close(self) -> None: + """Calls :meth:`stop` if applicable and the :meth:`Agent.close` method on all instantiated agents""" + # stop the runtime if it hasn't been stopped yet + if self._run_context is not None: + await self.stop() + # close all the agents that have been instantiated + for agent_id in self._instantiated_agents: + agent = await self._get_agent(agent_id) + await agent.close() + + async def stop(self) -> None: + """Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded.""" + if self._run_context is None: + raise RuntimeError("Runtime is not started") + + try: + await self._run_context.stop() + finally: + self._run_context = None + self._message_queue = Queue() + + async def stop_when_idle(self) -> None: + """Stop the runtime message processing loop when there is + no outstanding message being processed or queued. This is the most common way to stop the runtime.""" + if self._run_context is None: + raise RuntimeError("Runtime is not started") + + try: + await self._run_context.stop_when_idle() + finally: + self._run_context = None + self._message_queue = Queue() + + async def stop_when(self, condition: Callable[[], bool]) -> None: + """Stop the runtime message processing loop when the condition is met. + + .. caution:: + + This method is not recommended to be used, and is here for legacy + reasons. It will spawn a busy loop to continually check the + condition. It is much more efficient to call `stop_when_idle` or + `stop` instead. If you need to stop the runtime based on a + condition, consider using a background task and asyncio.Event to + signal when the condition is met and the background task should call + stop. + + """ + if self._run_context is None: + raise RuntimeError("Runtime is not started") + await self._run_context.stop_when(condition) + + self._run_context = None + self._message_queue = Queue() + + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: + return (await self._get_agent(agent)).metadata + + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + return await (await self._get_agent(agent)).save_state() + + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + await (await self._get_agent(agent)).load_state(state) + + async def register_factory( + self, + type: str | AgentType, + agent_factory: Callable[[], T | Awaitable[T]], + *, + expected_class: type[T] | None = None, + ) -> AgentType: + if isinstance(type, str): + type = AgentType(type) + + if type.type in self._agent_factories: + raise ValueError(f"Agent with type {type} already exists.") + + async def factory_wrapper() -> T: + maybe_agent_instance = agent_factory() + if inspect.isawaitable(maybe_agent_instance): + agent_instance = await maybe_agent_instance + else: + agent_instance = maybe_agent_instance + + if expected_class is not None and not issubclass(type_func_alias(agent_instance), expected_class): + raise ValueError( + f"Factory registered using the wrong type: expected {expected_class.__name__}, got {type_func_alias(agent_instance).__name__}" + ) + return agent_instance + + self._agent_factories[type.type] = factory_wrapper + + return type + + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + def agent_factory() -> Agent: + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) + self._instantiated_agents[agent_id] = agent_instance + return agent_id + + async def _invoke_agent_factory( + self, + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], + agent_id: AgentId, + ) -> T: + with AgentInstantiationContext.populate_context((self, agent_id)): + try: + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + warnings.warn( + "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.", + stacklevel=2, + ) + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") + + if inspect.isawaitable(agent): + agent = cast(T, await agent) + return agent + + except BaseException as e: + event_logger.info( + AgentConstructionExceptionEvent( + agent_id=agent_id, + exception=e, + ) + ) + logger.error(f"Error constructing agent {agent_id}", exc_info=True) + raise + + async def _get_agent(self, agent_id: AgentId) -> Agent: + if agent_id in self._instantiated_agents: + return self._instantiated_agents[agent_id] + + if agent_id.type not in self._agent_factories: + raise LookupError(f"Agent with name {agent_id.type} not found.") + + agent_factory = self._agent_factories[agent_id.type] + agent = await self._invoke_agent_factory(agent_factory, agent_id) + self._instantiated_agents[agent_id] = agent + return agent + + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + if id.type not in self._agent_factories: + raise LookupError(f"Agent with name {id.type} not found.") + + # TODO: check if remote + agent_instance = await self._get_agent(id) + + if not isinstance(agent_instance, type): + raise TypeError( + f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}" + ) + + return agent_instance + + async def add_subscription(self, subscription: Subscription) -> None: + await self._subscription_manager.add_subscription(subscription) + + async def remove_subscription(self, id: str) -> None: + await self._subscription_manager.remove_subscription(id) + + async def get( + self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True + ) -> AgentId: + return await get_impl( + id_or_type=id_or_type, + key=key, + lazy=lazy, + instance_getter=self._get_agent, + ) + + def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: + self._serialization_registry.add_serializer(serializer) + + def _try_serialize(self, message: Any) -> str: + try: + type_name = self._serialization_registry.type_name(message) + return self._serialization_registry.serialize( + message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE + ).decode("utf-8") + except ValueError: + return "Message could not be serialized" diff --git a/agent_dhal/agentdhal_core/_subscription.py b/agent_dhal/agentdhal_core/_subscription.py new file mode 100644 index 0000000..ddfda0d --- /dev/null +++ b/agent_dhal/agentdhal_core/_subscription.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Awaitable, Callable, Protocol, runtime_checkable + +from ._agent_id import AgentId +from ._topic import TopicId + + +@runtime_checkable +class Subscription(Protocol): + """Subscriptions define the topics that an agent is interested in.""" + + @property + def id(self) -> str: + """Get the ID of the subscription. + + Implementations should return a unique ID for the subscription. Usually this is a UUID. + + Returns: + str: ID of the subscription. + """ + ... + + def __eq__(self, other: object) -> bool: + """Check if two subscriptions are equal. + + Args: + other (object): Other subscription to compare against. + + Returns: + bool: True if the subscriptions are equal, False otherwise. + """ + if not isinstance(other, Subscription): + return False + + return self.id == other.id + + def is_match(self, topic_id: TopicId) -> bool: + """Check if a given topic_id matches the subscription. + + Args: + topic_id (TopicId): TopicId to check. + + Returns: + bool: True if the topic_id matches the subscription, False otherwise. + """ + ... + + def map_to_agent(self, topic_id: TopicId) -> AgentId: + """Map a topic_id to an agent. Should only be called if `is_match` returns True for the given topic_id. + + Args: + topic_id (TopicId): TopicId to map. + + Returns: + AgentId: ID of the agent that should handle the topic_id. + + Raises: + CantHandleException: If the subscription cannot handle the topic_id. + """ + ... + + +# Helper alias to represent the lambdas used to define subscriptions +UnboundSubscription = Callable[[], list[Subscription] | Awaitable[list[Subscription]]] diff --git a/agent_dhal/agentdhal_core/_subscription_context.py b/agent_dhal/agentdhal_core/_subscription_context.py new file mode 100644 index 0000000..29b1e16 --- /dev/null +++ b/agent_dhal/agentdhal_core/_subscription_context.py @@ -0,0 +1,33 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, ClassVar, Generator + +from ._agent_type import AgentType + + +class SubscriptionInstantiationContext: + def __init__(self) -> None: + raise RuntimeError( + "SubscriptionInstantiationContext cannot be instantiated. It is a static class that provides context management for subscription instantiation." + ) + + _SUBSCRIPTION_CONTEXT_VAR: ClassVar[ContextVar[AgentType]] = ContextVar("_SUBSCRIPTION_CONTEXT_VAR") + + @classmethod + @contextmanager + def populate_context(cls, ctx: AgentType) -> Generator[None, Any, None]: + """:meta private:""" + token = SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.set(ctx) + try: + yield + finally: + SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.reset(token) + + @classmethod + def agent_type(cls) -> AgentType: + try: + return cls._SUBSCRIPTION_CONTEXT_VAR.get() + except LookupError as e: + raise RuntimeError( + "SubscriptionInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." + ) from e diff --git a/agent_dhal/agentdhal_core/_telemetry/__init__.py b/agent_dhal/agentdhal_core/_telemetry/__init__.py new file mode 100644 index 0000000..c67591a --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/__init__.py @@ -0,0 +1,25 @@ +from ._genai import ( + trace_create_agent_span, + trace_invoke_agent_span, + trace_tool_span, +) +from ._propagation import ( + EnvelopeMetadata, + TelemetryMetadataContainer, + get_telemetry_envelope_metadata, + get_telemetry_grpc_metadata, +) +from ._tracing import TraceHelper +from ._tracing_config import MessageRuntimeTracingConfig + +__all__ = [ + "EnvelopeMetadata", + "get_telemetry_envelope_metadata", + "get_telemetry_grpc_metadata", + "TelemetryMetadataContainer", + "TraceHelper", + "MessageRuntimeTracingConfig", + "trace_create_agent_span", + "trace_invoke_agent_span", + "trace_tool_span", +] diff --git a/agent_dhal/agentdhal_core/_telemetry/_constants.py b/agent_dhal/agentdhal_core/_telemetry/_constants.py new file mode 100644 index 0000000..6348e48 --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/_constants.py @@ -0,0 +1 @@ +NAMESPACE = "agentdhal" diff --git a/agent_dhal/agentdhal_core/_telemetry/_genai.py b/agent_dhal/agentdhal_core/_telemetry/_genai.py new file mode 100644 index 0000000..aaecec7 --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/_genai.py @@ -0,0 +1,214 @@ +from collections.abc import Generator +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional + +from opentelemetry import trace +from opentelemetry.trace import Span, SpanKind + +from .._agent_instantiation import AgentInstantiationContext + +# OpenTelemetry semantic convention constants for GenAI operations +# Copied from opentelemetry-semantic-conventions to avoid dependency + +# GenAI Agent attributes +GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description" +GEN_AI_AGENT_ID = "gen_ai.agent.id" +GEN_AI_AGENT_NAME = "gen_ai.agent.name" + +# GenAI Operation attributes +GEN_AI_OPERATION_NAME = "gen_ai.operation.name" +GEN_AI_SYSTEM = "gen_ai.system" + +# GenAI Tool attributes +GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id" +GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description" +GEN_AI_TOOL_NAME = "gen_ai.tool.name" + +# Error attributes +ERROR_TYPE = "error.type" + + +class GenAiOperationNameValues(Enum): + """Enum for GenAI operation name values.""" + + CHAT = "chat" + CREATE_AGENT = "create_agent" + EMBEDDINGS = "embeddings" + EXECUTE_TOOL = "execute_tool" + GENERATE_CONTENT = "generate_content" + INVOKE_AGENT = "invoke_agent" + TEXT_COMPLETION = "text_completion" + + +# Constant for system name +GENAI_SYSTEM_AUTOGEN = "agentdhal" + + +@contextmanager +def trace_tool_span( + tool_name: str, + *, + tracer: Optional[trace.Tracer] = None, + parent: Optional[Span] = None, + tool_description: Optional[str] = None, + tool_call_id: Optional[str] = None, +) -> Generator[Span, Any, None]: + """Context manager to create a span for tool execution following the + OpenTelemetry Semantic conventions for generative AI systems. + + See the GenAI semantic conventions documentation: + `OpenTelemetry GenAI Semantic Conventions `__ + + .. warning:: + + The GenAI Semantic Conventions are still in incubation and + subject to changes in future releases. + + + Args: + tool_name (str): The name of the tool being executed. + tracer (Optional[trace.Tracer]): The tracer to use for creating the span. + parent (Optional[Span]): The parent span to link this span to. + tool_description (Optional[str]): A description of the tool. + tool_call_id (Optional[str]): A unique identifier for the tool call. + """ + if tracer is None: + tracer = trace.get_tracer("agentdhal-core") + span_attributes = { + GEN_AI_OPERATION_NAME: GenAiOperationNameValues.EXECUTE_TOOL.value, + GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN, + GEN_AI_TOOL_NAME: tool_name, + } + if tool_description is not None: + span_attributes[GEN_AI_TOOL_DESCRIPTION] = tool_description + if tool_call_id is not None: + span_attributes[GEN_AI_TOOL_CALL_ID] = tool_call_id + with tracer.start_as_current_span( + f"{GenAiOperationNameValues.EXECUTE_TOOL.value} {tool_name}", + kind=SpanKind.INTERNAL, + context=trace.set_span_in_context(parent) if parent else None, + attributes=span_attributes, + ) as span: + try: + yield span + except Exception as e: + # Set the exception details on the span if an error occurs + span.record_exception(e) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(e))) + span.set_attribute(ERROR_TYPE, type(e).__name__) + raise + + +@contextmanager +def trace_create_agent_span( + agent_name: str, + *, + tracer: Optional[trace.Tracer] = None, + parent: Optional[Span] = None, + agent_id: Optional[str] = None, + agent_description: Optional[str] = None, +) -> Generator[Span, Any, None]: + """Context manager to create a span for agent creation following the + OpenTelemetry Semantic conventions for generative AI systems. + + See the GenAI semantic conventions documentation: + `OpenTelemetry GenAI Semantic Conventions `__ + + .. warning:: + + The GenAI Semantic Conventions are still in incubation and + subject to changes in future releases. + + Args: + agent_name (str): The name of the agent being created. + tracer (Optional[trace.Tracer]): The tracer to use for creating the span. + parent (Optional[Span]): The parent span to link this span to. + agent_id (Optional[str]): The unique identifier for the agent. + agent_description (Optional[str]): A description of the agent. + """ + if tracer is None: + tracer = trace.get_tracer("agentdhal-core") + span_attributes = { + GEN_AI_OPERATION_NAME: GenAiOperationNameValues.CREATE_AGENT.value, + GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN, + GEN_AI_AGENT_NAME: agent_name, + } + if agent_id is None: + # Try to see if we can get the agent ID from the current context + try: + agent_id = str(AgentInstantiationContext.current_agent_id()) + except RuntimeError: + agent_id = None + if agent_id is not None: + span_attributes[GEN_AI_AGENT_ID] = agent_id + if agent_description is not None: + span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description + with tracer.start_as_current_span( + f"{GenAiOperationNameValues.CREATE_AGENT.value} {agent_name}", + kind=SpanKind.CLIENT, + context=trace.set_span_in_context(parent) if parent else None, + attributes=span_attributes, + ) as span: + try: + yield span + except Exception as e: + # Set the exception details on the span if an error occurs + span.record_exception(e) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(e))) + span.set_attribute(ERROR_TYPE, type(e).__name__) + raise + + +@contextmanager +def trace_invoke_agent_span( + agent_name: str, + *, + tracer: Optional[trace.Tracer] = None, + parent: Optional[Span] = None, + agent_id: Optional[str] = None, + agent_description: Optional[str] = None, +) -> Generator[Span, Any, None]: + """Context manager to create a span for invoking an agent following the + OpenTelemetry Semantic conventions for generative AI systems. + + See the GenAI semantic conventions documentation: + `OpenTelemetry GenAI Semantic Conventions `__ + + .. warning:: + + The GenAI Semantic Conventions are still in incubation and + subject to changes in future releases. + + Args: + agent_name (str): The name of the agent being invoked. + tracer (Optional[trace.Tracer]): The tracer to use for creating the span. + parent (Optional[Span]): The parent span to link this span to. + agent_id (Optional[str]): The unique identifier for the agent. + agent_description (Optional[str]): A description of the agent. + """ + if tracer is None: + tracer = trace.get_tracer("agentdhal-core") + span_attributes = { + GEN_AI_OPERATION_NAME: GenAiOperationNameValues.INVOKE_AGENT.value, + GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN, + GEN_AI_AGENT_NAME: agent_name, + } + if agent_id is not None: + span_attributes[GEN_AI_AGENT_ID] = agent_id + if agent_description is not None: + span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description + with tracer.start_as_current_span( + f"{GenAiOperationNameValues.INVOKE_AGENT.value} {agent_name}", + kind=SpanKind.CLIENT, + context=trace.set_span_in_context(parent) if parent else None, + attributes=span_attributes, + ) as span: + try: + yield span + except Exception as e: + # Set the exception details on the span if an error occurs + span.record_exception(e) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(e))) + span.set_attribute(ERROR_TYPE, type(e).__name__) + raise diff --git a/agent_dhal/agentdhal_core/_telemetry/_propagation.py b/agent_dhal/agentdhal_core/_telemetry/_propagation.py new file mode 100644 index 0000000..a3834e1 --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/_propagation.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass +from typing import Dict, Mapping, Optional, Sequence + +from opentelemetry.context import Context +from opentelemetry.propagate import extract +from opentelemetry.trace import Link, get_current_span +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + +@dataclass(kw_only=True) +class EnvelopeMetadata: + """Metadata for an envelope.""" + + traceparent: Optional[str] = None + tracestate: Optional[str] = None + links: Optional[Sequence[Link]] = None + + +def _get_carrier_for_envelope_metadata(envelope_metadata: EnvelopeMetadata) -> Dict[str, str]: + carrier: Dict[str, str] = {} + if envelope_metadata.traceparent is not None: + carrier["traceparent"] = envelope_metadata.traceparent + if envelope_metadata.tracestate is not None: + carrier["tracestate"] = envelope_metadata.tracestate + return carrier + + +def get_telemetry_envelope_metadata() -> EnvelopeMetadata: + """ + Retrieves the telemetry envelope metadata. + + Returns: + EnvelopeMetadata: The envelope metadata containing the traceparent and tracestate. + """ + carrier: Dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return EnvelopeMetadata( + traceparent=carrier.get("traceparent"), + tracestate=carrier.get("tracestate"), + ) + + +def _get_carrier_for_remote_call_metadata(remote_call_metadata: Mapping[str, str]) -> Dict[str, str]: + carrier: Dict[str, str] = {} + traceparent = remote_call_metadata.get("traceparent") + tracestate = remote_call_metadata.get("tracestate") + if traceparent: + carrier["traceparent"] = traceparent + if tracestate: + carrier["tracestate"] = tracestate + return carrier + + +def get_telemetry_grpc_metadata(existingMetadata: Optional[Mapping[str, str]] = None) -> Dict[str, str]: + """ + Retrieves the telemetry gRPC metadata. + + Args: + existingMetadata (Optional[Mapping[str, str]]): The existing metadata to include in the gRPC metadata. + + Returns: + Mapping[str, str]: The gRPC metadata containing the traceparent and tracestate. + """ + carrier: Dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + traceparent = carrier.get("traceparent") + tracestate = carrier.get("tracestate") + metadata: Dict[str, str] = {} + if existingMetadata is not None: + for key, value in existingMetadata.items(): + metadata[key] = value + if traceparent is not None: + metadata["traceparent"] = traceparent + if tracestate is not None: + metadata["tracestate"] = tracestate + return metadata + + +TelemetryMetadataContainer = Optional[EnvelopeMetadata] | Mapping[str, str] + + +def get_telemetry_context(metadata: TelemetryMetadataContainer) -> Context: + """ + Retrieves the telemetry context from the given metadata. + + Args: + metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry context. + + Returns: + Context: The telemetry context extracted from the metadata, or an empty context if the metadata is None. + """ + if metadata is None: + return Context() + elif isinstance(metadata, EnvelopeMetadata): + return extract(_get_carrier_for_envelope_metadata(metadata)) + elif hasattr(metadata, "__getitem__"): + return extract(_get_carrier_for_remote_call_metadata(metadata)) + else: + raise ValueError(f"Unknown metadata type: {type(metadata)}") + + +def get_telemetry_links( + metadata: TelemetryMetadataContainer, +) -> Optional[Sequence[Link]]: + """ + Retrieves the telemetry links from the given metadata. + + Args: + metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry links. + + Returns: + Optional[Sequence[Link]]: The telemetry links extracted from the metadata, or None if there are no links. + """ + if metadata is None: + return None + elif isinstance(metadata, EnvelopeMetadata): + context = extract(_get_carrier_for_envelope_metadata(metadata)) + elif hasattr(metadata, "__getitem__"): + context = extract(_get_carrier_for_remote_call_metadata(metadata)) + else: + return None + # Retrieve the extracted SpanContext from the context. + linked_span = get_current_span(context) + # Use the linked span to get the SpanContext. + span_context = linked_span.get_span_context() + # Create a Link object using the SpanContext. + return [Link(span_context)] diff --git a/agent_dhal/agentdhal_core/_telemetry/_tracing.py b/agent_dhal/agentdhal_core/_telemetry/_tracing.py new file mode 100644 index 0000000..c31e647 --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/_tracing.py @@ -0,0 +1,99 @@ +import contextlib +import os +from typing import Dict, Generic, Iterator, Optional + +from opentelemetry.trace import NoOpTracerProvider, Span, SpanKind, TracerProvider, get_tracer_provider +from opentelemetry.util import types + +from ._propagation import TelemetryMetadataContainer, get_telemetry_links +from ._tracing_config import Destination, ExtraAttributes, Operation, TracingConfig + + +class TraceHelper(Generic[Operation, Destination, ExtraAttributes]): + """ + TraceHelper is a utility class to assist with tracing operations using OpenTelemetry. + + This class provides a context manager `trace_block` to create and manage spans for tracing operations, + following semantic conventions and supporting nested spans through metadata contexts. + + """ + + def __init__( + self, + tracer_provider: TracerProvider | None, + instrumentation_builder_config: TracingConfig[Operation, Destination, ExtraAttributes], + ) -> None: + self.instrumentation_builder_config = instrumentation_builder_config + + disable_runtime_tracing = os.environ.get("AUTOGEN_DISABLE_RUNTIME_TRACING") == "true" + if disable_runtime_tracing: + self.tracer_provider: TracerProvider = NoOpTracerProvider() + self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}") + return + + # Evaluate in order: first try tracer_provider param, then get_tracer_provider(), finally fallback to NoOp + # This allows for nested tracing with a default tracer provided by the user + self.tracer_provider = tracer_provider or get_tracer_provider() or NoOpTracerProvider() + self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}") + + @contextlib.contextmanager + def trace_block( + self, + operation: Operation, + destination: Destination, + parent: Optional[TelemetryMetadataContainer], + *, + extraAttributes: ExtraAttributes | None = None, + kind: Optional[SpanKind] = None, + attributes: Optional[types.Attributes] = None, + start_time: Optional[int] = None, + record_exception: bool = True, + set_status_on_exception: bool = True, + end_on_exit: bool = True, + ) -> Iterator[Span]: + """ + Thin wrapper on top of start_as_current_span. + 1. It helps us follow semantic conventions + 2. It helps us get contexts from metadata so we can get nested spans + + Args: + operation (MessagingOperation): The messaging operation being performed. + destination (MessagingDestination): The messaging destination being used. + parent Optional[TelemetryMetadataContainer]: The parent telemetry metadta context + kind (SpanKind, optional): The kind of span. If not provided, it maps to PRODUCER or CONSUMER depending on the operation. + extraAttributes (ExtraAttributes, optional): Additional defined attributes for the span. Defaults to None. + attributes (Optional[types.Attributes], optional): Additional non-defined attributes for the span. Defaults to None. + start_time (Optional[int], optional): The start time of the span. Defaults to None. + record_exception (bool, optional): Whether to record exceptions. Defaults to True. + set_status_on_exception (bool, optional): Whether to set the status on exception. Defaults to True. + end_on_exit (bool, optional): Whether to end the span on exit. Defaults to True. + + Yields: + Iterator[Span]: The span object. + + """ + span_name = self.instrumentation_builder_config.get_span_name(operation, destination) + span_kind = kind or self.instrumentation_builder_config.get_span_kind(operation) + # context = get_telemetry_context(parent) if parent else None + context = None # TODO: we may need to remove other code for using custom context. + links = get_telemetry_links(parent) if parent else None + attributes_with_defaults: Dict[str, types.AttributeValue] = {} + for key, value in (attributes or {}).items(): + attributes_with_defaults[key] = value + instrumentation_attributes = self.instrumentation_builder_config.build_attributes( + operation, destination, extraAttributes + ) + for key, value in instrumentation_attributes.items(): + attributes_with_defaults[key] = value + with self.tracer.start_as_current_span( + span_name, + context, + span_kind, + attributes_with_defaults, + links, + start_time, + record_exception, + set_status_on_exception, + end_on_exit, + ) as span: + yield span diff --git a/agent_dhal/agentdhal_core/_telemetry/_tracing_config.py b/agent_dhal/agentdhal_core/_telemetry/_tracing_config.py new file mode 100644 index 0000000..10e54b4 --- /dev/null +++ b/agent_dhal/agentdhal_core/_telemetry/_tracing_config.py @@ -0,0 +1,201 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, Generic, List, Literal, TypedDict, TypeVar, Union + +from opentelemetry.trace import SpanKind +from opentelemetry.util import types +from typing_extensions import NotRequired + +from .._agent_id import AgentId +from .._topic import TopicId +from ._constants import NAMESPACE + +logger = logging.getLogger("agentdhal_core") +event_logger = logging.getLogger("agentdhal_core.events") + +Operation = TypeVar("Operation", bound=str) +Destination = TypeVar("Destination") +ExtraAttributes = TypeVar("ExtraAttributes") + + +class TracingConfig(ABC, Generic[Operation, Destination, ExtraAttributes]): + """ + A protocol that defines the configuration for instrumentation. + + This protocol specifies the required properties and methods that any + instrumentation configuration class must implement. It includes a + property to get the name of the module being instrumented and a method + to build attributes for the instrumentation configuration. + """ + + @property + @abstractmethod + def name(self) -> str: + """ + Returns: + The name of the module that is being instrumented. + """ + ... + + @abstractmethod + def build_attributes( + self, + operation: Operation, + destination: Destination, + extraAttributes: ExtraAttributes | None, + ) -> Dict[str, types.AttributeValue]: + """ + Builds the attributes for the instrumentation configuration. + + Returns: + Dict[str, str]: The attributes for the instrumentation configuration. + """ + ... + + @abstractmethod + def get_span_name( + self, + operation: Operation, + destination: Destination, + ) -> str: + """ + Returns the span name based on the given operation and destination. + + Parameters: + operation (MessagingOperation): The messaging operation. + destination (Optional[MessagingDestination]): The messaging destination. + + Returns: + str: The span name. + """ + ... + + @abstractmethod + def get_span_kind( + self, + operation: Operation, + ) -> SpanKind: + """ + Determines the span kind based on the given messaging operation. + + Parameters: + operation (MessagingOperation): The messaging operation. + + Returns: + SpanKind: The span kind based on the messaging operation. + """ + + +class ExtraMessageRuntimeAttributes(TypedDict): + message_size: NotRequired[int] + message_type: NotRequired[str] + + +MessagingDestination = Union[AgentId, TopicId, str, None] +MessagingOperation = Literal["create", "send", "publish", "receive", "intercept", "process", "ack"] + + +class MessageRuntimeTracingConfig( + TracingConfig[MessagingOperation, MessagingDestination, ExtraMessageRuntimeAttributes] +): + """ + A class that defines the configuration for message runtime instrumentation. + + This class implements the TracingConfig protocol and provides + the name of the module being instrumented and the attributes for the + instrumentation configuration. + """ + + def __init__(self, runtime_name: str) -> None: + self._runtime_name = runtime_name + + @property + def name(self) -> str: + return self._runtime_name + + def build_attributes( + self, + operation: MessagingOperation, + destination: MessagingDestination, + extraAttributes: ExtraMessageRuntimeAttributes | None, + ) -> Dict[str, types.AttributeValue]: + attrs: Dict[str, types.AttributeValue] = { + "messaging.operation": self._get_operation_type(operation), + "messaging.destination": self._get_destination_str(destination), + } + if extraAttributes: + # TODO: Make this more pythonic? + if "message_size" in extraAttributes: + attrs["messaging.message.envelope.size"] = extraAttributes["message_size"] + if "message_type" in extraAttributes: + attrs["messaging.message.type"] = extraAttributes["message_type"] + return attrs + + def get_span_name( + self, + operation: MessagingOperation, + destination: MessagingDestination, + ) -> str: + """ + Returns the span name based on the given operation and destination. + Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-name + + Parameters: + operation (MessagingOperation): The messaging operation. + destination (Optional[MessagingDestination]): The messaging destination. + + Returns: + str: The span name. + """ + span_parts: List[str] = [operation] + destination_str = self._get_destination_str(destination) + if destination_str: + span_parts.append(destination_str) + span_name = " ".join(span_parts) + return f"{NAMESPACE} {span_name}" + + def get_span_kind( + self, + operation: MessagingOperation, + ) -> SpanKind: + """ + Determines the span kind based on the given messaging operation. + Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-kind + + Parameters: + operation (MessagingOperation): The messaging operation. + + Returns: + SpanKind: The span kind based on the messaging operation. + """ + if operation in ["create", "send", "publish"]: + return SpanKind.PRODUCER + elif operation in ["receive", "intercept", "process", "ack"]: + return SpanKind.CONSUMER + else: + return SpanKind.CLIENT + + # TODO: Use stringified convention + def _get_destination_str(self, destination: MessagingDestination) -> str: + if isinstance(destination, AgentId): + return f"{destination.type}.({destination.key})-A" + elif isinstance(destination, TopicId): + return f"{destination.type}.({destination.source})-T" + elif isinstance(destination, str): + return destination + elif destination is None: + return "" + else: + raise ValueError(f"Unknown destination type: {type(destination)}") + + def _get_operation_type(self, operation: MessagingOperation) -> str: + if operation in ["send", "publish"]: + return "publish" + if operation in ["create"]: + return "create" + elif operation in ["receive", "intercept", "ack"]: + return "receive" + elif operation in ["process"]: + return "process" + else: + return "Unknown" diff --git a/agent_dhal/agentdhal_core/_topic.py b/agent_dhal/agentdhal_core/_topic.py new file mode 100644 index 0000000..67d4a24 --- /dev/null +++ b/agent_dhal/agentdhal_core/_topic.py @@ -0,0 +1,47 @@ +import re +from dataclasses import dataclass + +from typing_extensions import Self + + +def is_valid_topic_type(value: str) -> bool: + return bool(re.match(r"^[\w\-\.\:\=]+\Z", value)) + + +@dataclass(eq=True, frozen=True) +class TopicId: + """ + TopicId defines the scope of a broadcast message. In essence, agent runtime implements a publish-subscribe model through its broadcast API: when publishing a message, the topic must be specified. + + See here for more information: :ref:`topic_and_subscription_topic` + """ + + type: str + """Type of the event that this topic_id contains. Adhere's to the cloud event spec. + + Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z + + Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#type + """ + + source: str + """Identifies the context in which an event happened. Adhere's to the cloud event spec. + + Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#source-1 + """ + + def __post_init__(self) -> None: + if is_valid_topic_type(self.type) is False: + raise ValueError(f"Invalid topic type: {self.type}. Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z") + + def __str__(self) -> str: + return f"{self.type}/{self.source}" + + @classmethod + def from_str(cls, topic_id: str) -> Self: + """Convert a string of the format ``type/source`` into a TopicId""" + items = topic_id.split("/", maxsplit=1) + if len(items) != 2: + raise ValueError(f"Invalid topic id: {topic_id}") + type, source = items[0], items[1] + return cls(type, source) diff --git a/agent_dhal/agentdhal_core/_type_helpers.py b/agent_dhal/agentdhal_core/_type_helpers.py new file mode 100644 index 0000000..66e52e4 --- /dev/null +++ b/agent_dhal/agentdhal_core/_type_helpers.py @@ -0,0 +1,33 @@ +from collections.abc import Sequence +from types import NoneType, UnionType +from typing import Any, Optional, Type, Union, get_args, get_origin + + +def is_union(t: object) -> bool: + origin = get_origin(t) + return origin is Union or origin is UnionType + + +def is_optional(t: object) -> bool: + origin = get_origin(t) + return origin is Optional + + +# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any +class AnyType: + pass + + +def get_types(t: object) -> Sequence[Type[Any]] | None: + if is_union(t): + return get_args(t) + elif is_optional(t): + return tuple(list(get_args(t)) + [NoneType]) + elif t is Any: + return (AnyType,) + elif isinstance(t, type): + return (t,) + elif isinstance(t, NoneType): + return (NoneType,) + else: + return None diff --git a/agent_dhal/agentdhal_core/_type_prefix_subscription.py b/agent_dhal/agentdhal_core/_type_prefix_subscription.py new file mode 100644 index 0000000..49fc88d --- /dev/null +++ b/agent_dhal/agentdhal_core/_type_prefix_subscription.py @@ -0,0 +1,69 @@ +import uuid + +from ._agent_id import AgentId +from ._agent_type import AgentType +from ._subscription import Subscription +from ._topic import TopicId +from .exceptions import CantHandleException + + +class TypePrefixSubscription(Subscription): + """This subscription matches on topics based on a prefix of the type and maps to agents using the source of the topic as the agent key. + + This subscription causes each source to have its own agent instance. + + Example: + + .. code-block:: python + + from agentdhal_core import TypePrefixSubscription + + subscription = TypePrefixSubscription(topic_type_prefix="t1", agent_type="a1") + + In this case: + + - A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1` + - A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`. + - A topic_id with type `t1SUFFIX` and source `s2` will be handled by an agent of type `a1` with key `s2`. + + Args: + topic_type_prefix (str): Topic type prefix to match against + agent_type (str): Agent type to handle this subscription + """ + + def __init__(self, topic_type_prefix: str, agent_type: str | AgentType, id: str | None = None): + self._topic_type_prefix = topic_type_prefix + if isinstance(agent_type, AgentType): + self._agent_type = agent_type.type + else: + self._agent_type = agent_type + self._id = id or str(uuid.uuid4()) + + @property + def id(self) -> str: + return self._id + + @property + def topic_type_prefix(self) -> str: + return self._topic_type_prefix + + @property + def agent_type(self) -> str: + return self._agent_type + + def is_match(self, topic_id: TopicId) -> bool: + return topic_id.type.startswith(self._topic_type_prefix) + + def map_to_agent(self, topic_id: TopicId) -> AgentId: + if not self.is_match(topic_id): + raise CantHandleException("TopicId does not match the subscription") + + return AgentId(type=self._agent_type, key=topic_id.source) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypePrefixSubscription): + return False + + return self.id == other.id or ( + self.agent_type == other.agent_type and self.topic_type_prefix == other.topic_type_prefix + ) diff --git a/agent_dhal/agentdhal_core/_type_subscription.py b/agent_dhal/agentdhal_core/_type_subscription.py new file mode 100644 index 0000000..dbb00d3 --- /dev/null +++ b/agent_dhal/agentdhal_core/_type_subscription.py @@ -0,0 +1,66 @@ +import uuid + +from ._agent_id import AgentId +from ._agent_type import AgentType +from ._subscription import Subscription +from ._topic import TopicId +from .exceptions import CantHandleException + + +class TypeSubscription(Subscription): + """This subscription matches on topics based on the type and maps to agents using the source of the topic as the agent key. + + This subscription causes each source to have its own agent instance. + + Example: + + .. code-block:: python + + from agentdhal_core import TypeSubscription + + subscription = TypeSubscription(topic_type="t1", agent_type="a1") + + In this case: + + - A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1` + - A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`. + + Args: + topic_type (str): Topic type to match against + agent_type (str): Agent type to handle this subscription + """ + + def __init__(self, topic_type: str, agent_type: str | AgentType, id: str | None = None): + self._topic_type = topic_type + if isinstance(agent_type, AgentType): + self._agent_type = agent_type.type + else: + self._agent_type = agent_type + self._id = id or str(uuid.uuid4()) + + @property + def id(self) -> str: + return self._id + + @property + def topic_type(self) -> str: + return self._topic_type + + @property + def agent_type(self) -> str: + return self._agent_type + + def is_match(self, topic_id: TopicId) -> bool: + return topic_id.type == self._topic_type + + def map_to_agent(self, topic_id: TopicId) -> AgentId: + if not self.is_match(topic_id): + raise CantHandleException("TopicId does not match the subscription") + + return AgentId(type=self._agent_type, key=topic_id.source) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypeSubscription): + return False + + return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type) diff --git a/agent_dhal/agentdhal_core/_types.py b/agent_dhal/agentdhal_core/_types.py new file mode 100644 index 0000000..5e3850f --- /dev/null +++ b/agent_dhal/agentdhal_core/_types.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class FunctionCall: + id: str + # JSON args + arguments: str + # Function to call + name: str diff --git a/agent_dhal/agentdhal_core/code_executor/__init__.py b/agent_dhal/agentdhal_core/code_executor/__init__.py new file mode 100644 index 0000000..f1789a5 --- /dev/null +++ b/agent_dhal/agentdhal_core/code_executor/__init__.py @@ -0,0 +1,21 @@ +from ._base import CodeBlock, CodeExecutor, CodeResult +from ._func_with_reqs import ( + Alias, + FunctionWithRequirements, + FunctionWithRequirementsStr, + Import, + ImportFromModule, + with_requirements, +) + +__all__ = [ + "CodeBlock", + "CodeExecutor", + "CodeResult", + "Alias", + "ImportFromModule", + "Import", + "FunctionWithRequirements", + "FunctionWithRequirementsStr", + "with_requirements", +] diff --git a/agent_dhal/agentdhal_core/code_executor/_base.py b/agent_dhal/agentdhal_core/code_executor/_base.py new file mode 100644 index 0000000..16be434 --- /dev/null +++ b/agent_dhal/agentdhal_core/code_executor/_base.py @@ -0,0 +1,102 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py +# Credit to original authors + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from types import TracebackType +from typing import List, Optional, Type + +from pydantic import BaseModel +from typing_extensions import Self + +from .._cancellation_token import CancellationToken +from .._component_config import ComponentBase + + +@dataclass +class CodeBlock: + """A code block extracted fromm an agent message.""" + + code: str + language: str + + +@dataclass +class CodeResult: + """Result of a code execution.""" + + exit_code: int + output: str + + +class CodeExecutor(ABC, ComponentBase[BaseModel]): + """Executes code blocks and returns the result. + + This is an abstract base class for code executors. It defines the interface + for executing code blocks and returning the result. A concrete implementation + of this class should be provided to execute code blocks in a specific + environment. For example, :class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` executes + code blocks in a command line environment in a Docker container. + + It is recommended for subclass to be used as a context manager to ensure + that resources are cleaned up properly. To do this, implement the + :meth:`~agentdhal_core.code_executor.CodeExecutor.start` and + :meth:`~agentdhal_core.code_executor.CodeExecutor.stop` methods + that will be called when entering and exiting the context manager. + + """ + + component_type = "code_executor" + + @abstractmethod + async def execute_code_blocks( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CodeResult: + """Execute code blocks and return the result. + + This method should be implemented by the code executor. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CodeResult: The result of the code execution. + + Raises: + ValueError: Errors in user inputs + asyncio.TimeoutError: Code execution timeouts + asyncio.CancelledError: CancellationToken evoked during execution + """ + ... + + @abstractmethod + async def start(self) -> None: + """Start the code executor.""" + ... + + @abstractmethod + async def stop(self) -> None: + """Stop the code executor and release any resources.""" + ... + + @abstractmethod + async def restart(self) -> None: + """Restart the code executor. + + This method should be implemented by the code executor. + + This method is called when the agent is reset. + """ + ... + + async def __aenter__(self) -> Self: + await self.start() + return self + + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> Optional[bool]: + await self.stop() + return None diff --git a/agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py b/agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py new file mode 100644 index 0000000..da8c647 --- /dev/null +++ b/agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py @@ -0,0 +1,277 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py +# Credit to original authors + +from __future__ import annotations + +import functools +import inspect +from dataclasses import dataclass, field +from importlib.abc import SourceLoader +from importlib.util import module_from_spec, spec_from_loader +from textwrap import dedent, indent +from typing import Any, Callable, Generic, List, Sequence, Set, Tuple, TypeVar, Union + +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + if isinstance(func, FunctionWithRequirementsStr): + return func.func + + if isinstance(func, FunctionWithRequirements): + code = inspect.getsource(func.func) + else: + code = inspect.getsource(func) + # Strip the decorator + if code.startswith("@"): + code = code[code.index("\n") + 1 :] + return code + + +@dataclass(frozen=True) +class Alias: + name: str + alias: str + + +@dataclass(frozen=True) +class ImportFromModule: + module: str + imports: Tuple[Union[str, Alias], ...] + + # backward compatibility + def __init__( + self, + module: str, + imports: Union[Tuple[Union[str, Alias], ...], List[Union[str, Alias]]], + ): + object.__setattr__(self, "module", module) + if isinstance(imports, list): + object.__setattr__(self, "imports", tuple(imports)) + else: + object.__setattr__(self, "imports", imports) + + +Import = Union[str, ImportFromModule, Alias] + + +def _import_to_str(im: Import) -> str: + if isinstance(im, str): + return f"import {im}" + elif isinstance(im, Alias): + return f"import {im.name} as {im.alias}" + else: + + def to_str(i: Union[str, Alias]) -> str: + if isinstance(i, str): + return i + else: + return f"{i.name} as {i.alias}" + + imports = ", ".join(map(to_str, im.imports)) + return f"from {im.module} import {imports}" + + +class _StringLoader(SourceLoader): + def __init__(self, data: str): + self.data = data + + def get_source(self, fullname: str) -> str: + return self.data + + def get_data(self, path: str) -> bytes: + return self.data.encode("utf-8") + + def get_filename(self, fullname: str) -> str: + return "/" + fullname + ".py" + + +@dataclass +class FunctionWithRequirementsStr: + func: str + compiled_func: Callable[..., Any] + _func_name: str + python_packages: Sequence[str] = field(default_factory=list) + global_imports: Sequence[Import] = field(default_factory=list) + + def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []): + self.func = func + self.python_packages = python_packages + self.global_imports = global_imports + + module_name = "func_module" + loader = _StringLoader(func) + spec = spec_from_loader(module_name, loader) + if spec is None: + raise ValueError("Could not create spec") + module = module_from_spec(spec) + if spec.loader is None: + raise ValueError("Could not create loader") + + try: + spec.loader.exec_module(module) + except Exception as e: + raise ValueError(f"Could not compile function: {e}") from e + + functions = inspect.getmembers(module, inspect.isfunction) + if len(functions) != 1: + raise ValueError("The string must contain exactly one function") + + self._func_name, self.compiled_func = functions[0] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("String based function with requirement objects are not directly callable") + + +@dataclass +class FunctionWithRequirements(Generic[T, P]): + func: Callable[P, T] + python_packages: Sequence[str] = field(default_factory=list) + global_imports: Sequence[Import] = field(default_factory=list) + + @classmethod + def from_callable( + cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] + ) -> FunctionWithRequirements[T, P]: + return cls(python_packages=python_packages, global_imports=global_imports, func=func) + + @staticmethod + def from_str( + func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] + ) -> FunctionWithRequirementsStr: + return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports) + + # Type this based on F + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self.func(*args, **kwargs) + + +def with_requirements( + python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] +) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: + """ + Decorate a function with package and import requirements for code execution environments. + + This decorator makes a function available for reference in dynamically executed code blocks + by wrapping it in a `FunctionWithRequirements` object that tracks its dependencies. When the + decorated function is passed to a code executor, it can be imported by name in the executed + code, with all dependencies automatically handled. + + Args: + python_packages (Sequence[str], optional): Python packages required by the function. + Can include version specifications (e.g., ["pandas>=1.0.0"]). Defaults to []. + global_imports (Sequence[Import], optional): Import statements required by the function. + Can be strings ("numpy"), ImportFromModule objects, or Alias objects. Defaults to []. + + Returns: + Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: A decorator that wraps + the target function, preserving its functionality while registering its dependencies. + + Example: + + .. code-block:: python + + import tempfile + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_core.code_executor import with_requirements, CodeBlock + from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor + import pandas + + @with_requirements(python_packages=["pandas"], global_imports=["pandas"]) + def load_data() -> pandas.DataFrame: + \"\"\"Load some sample data. + + Returns: + pandas.DataFrame: A DataFrame with sample data + \"\"\" + data = { + "name": ["John", "Anna", "Peter", "Linda"], + "location": ["New York", "Paris", "Berlin", "London"], + "age": [24, 13, 53, 33], + } + return pandas.DataFrame(data) + + async def run_example(): + # The decorated function can be used in executed code + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data]) + code = f\"\"\"from {executor.functions_module} import load_data + + # Use the imported function + data = load_data() + print(data['name'][0])\"\"\" + + result = await executor.execute_code_blocks( + code_blocks=[CodeBlock(language="python", code=code)], + cancellation_token=CancellationToken(), + ) + print(result.output) # Output: John + + # Run the async example + asyncio.run(run_example()) + """ + + def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]: + func_with_reqs = FunctionWithRequirements( + python_packages=python_packages, global_imports=global_imports, func=func + ) + + functools.update_wrapper(func_with_reqs, func) + return func_with_reqs + + return wrapper + + +def build_python_functions_file( + funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]], +) -> str: + """:meta private:""" + # First collect all global imports + global_imports: Set[Import] = set() + for func in funcs: + if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): + global_imports.update(func.global_imports) + + content = "\n".join(map(_import_to_str, global_imports)) + "\n\n" + + for func in funcs: + content += _to_code(func) + "\n\n" + + return content + + +def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str: + """Generate a stub for a function as a string + + Args: + func (Callable[..., Any]): The function to generate a stub for + + Returns: + str: The stub for the function + """ + if isinstance(func, FunctionWithRequirementsStr): + return to_stub(func.compiled_func) + + content = f"def {func.__name__}{inspect.signature(func)}:\n" + docstring = func.__doc__ + + if docstring: + docstring = dedent(docstring) + docstring = '"""' + docstring + '"""' + docstring = indent(docstring, " ") + content += docstring + "\n" + + content += " ..." + return content + + +def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + return _to_code(func) + + +def import_to_str(im: Import) -> str: + return _import_to_str(im) diff --git a/agent_dhal/agentdhal_core/exceptions.py b/agent_dhal/agentdhal_core/exceptions.py new file mode 100644 index 0000000..3f4d76d --- /dev/null +++ b/agent_dhal/agentdhal_core/exceptions.py @@ -0,0 +1,17 @@ +__all__ = ["CantHandleException", "UndeliverableException", "MessageDroppedException", "NotAccessibleError"] + + +class CantHandleException(Exception): + """Raised when a handler can't handle the exception.""" + + +class UndeliverableException(Exception): + """Raised when a message can't be delivered.""" + + +class MessageDroppedException(Exception): + """Raised when a message is dropped.""" + + +class NotAccessibleError(Exception): + """Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally.""" diff --git a/agent_dhal/agentdhal_core/logging.py b/agent_dhal/agentdhal_core/logging.py new file mode 100644 index 0000000..3b333ce --- /dev/null +++ b/agent_dhal/agentdhal_core/logging.py @@ -0,0 +1,294 @@ +import json +from enum import Enum +from typing import Any, Dict, List, cast + +from ._agent_id import AgentId +from ._message_handler_context import MessageHandlerContext +from ._topic import TopicId + + +class LLMCallEvent: + def __init__( + self, + *, + messages: List[Dict[str, Any]], + response: Dict[str, Any], + prompt_tokens: int, + completion_tokens: int, + **kwargs: Any, + ) -> None: + """To be used by model clients to log the call to the LLM. + + Args: + messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable. + response (Dict[str, Any]): The response of the call. Must be json serializable. + prompt_tokens (int): Number of tokens used in the prompt. + completion_tokens (int): Number of tokens used in the completion. + + Example: + + .. code-block:: python + + import logging + from agentdhal_core import EVENT_LOGGER_NAME + from agentdhal_core.logging import LLMCallEvent + + response = {"content": "Hello, world!"} + messages = [{"role": "user", "content": "Hello, world!"}] + logger = logging.getLogger(EVENT_LOGGER_NAME) + logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20, response=response, messages=messages)) + + """ + self.kwargs = kwargs + self.kwargs["type"] = "LLMCall" + self.kwargs["messages"] = messages + self.kwargs["response"] = response + self.kwargs["prompt_tokens"] = prompt_tokens + self.kwargs["completion_tokens"] = completion_tokens + try: + agent_id = MessageHandlerContext.agent_id() + except RuntimeError: + agent_id = None + self.kwargs["agent_id"] = None if agent_id is None else str(agent_id) + + @property + def prompt_tokens(self) -> int: + return cast(int, self.kwargs["prompt_tokens"]) + + @property + def completion_tokens(self) -> int: + return cast(int, self.kwargs["completion_tokens"]) + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class LLMStreamStartEvent: + """To be used by model clients to log the start of a stream. + + Args: + messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable. + + Example: + + .. code-block:: python + + import logging + from agentdhal_core import EVENT_LOGGER_NAME + from agentdhal_core.logging import LLMStreamStartEvent + + messages = [{"role": "user", "content": "Hello, world!"}] + logger = logging.getLogger(EVENT_LOGGER_NAME) + logger.info(LLMStreamStartEvent(messages=messages)) + + """ + + def __init__( + self, + *, + messages: List[Dict[str, Any]], + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["type"] = "LLMStreamStart" + self.kwargs["messages"] = messages + try: + agent_id = MessageHandlerContext.agent_id() + except RuntimeError: + agent_id = None + self.kwargs["agent_id"] = None if agent_id is None else str(agent_id) + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class LLMStreamEndEvent: + def __init__( + self, + *, + response: Dict[str, Any], + prompt_tokens: int, + completion_tokens: int, + **kwargs: Any, + ) -> None: + """To be used by model clients to log the end of a stream. + + Args: + response (Dict[str, Any]): The response of the call. Must be json serializable. + prompt_tokens (int): Number of tokens used in the prompt. + completion_tokens (int): Number of tokens used in the completion. + + Example: + + .. code-block:: python + + import logging + from agentdhal_core import EVENT_LOGGER_NAME + from agentdhal_core.logging import LLMStreamEndEvent + + response = {"content": "Hello, world!"} + logger = logging.getLogger(EVENT_LOGGER_NAME) + logger.info(LLMStreamEndEvent(prompt_tokens=10, completion_tokens=20, response=response)) + + """ + self.kwargs = kwargs + self.kwargs["type"] = "LLMStreamEnd" + self.kwargs["response"] = response + self.kwargs["prompt_tokens"] = prompt_tokens + self.kwargs["completion_tokens"] = completion_tokens + try: + agent_id = MessageHandlerContext.agent_id() + except RuntimeError: + agent_id = None + self.kwargs["agent_id"] = None if agent_id is None else str(agent_id) + + @property + def prompt_tokens(self) -> int: + return cast(int, self.kwargs["prompt_tokens"]) + + @property + def completion_tokens(self) -> int: + return cast(int, self.kwargs["completion_tokens"]) + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class ToolCallEvent: + def __init__( + self, + *, + tool_name: str, + arguments: Dict[str, Any], + result: str, + ) -> None: + """Used by subclasses of :class:`~agentdhal_core.tools.BaseTool` to log executions of tools. + + Args: + tool_name (str): The name of the tool. + arguments (Dict[str, Any]): The arguments of the tool. Must be json serializable. + result (str): The result of the tool. Must be a string. + + Example: + + .. code-block:: python + + from agentdhal_core import EVENT_LOGGER_NAME + from agentdhal_core.logging import ToolCallEvent + + logger = logging.getLogger(EVENT_LOGGER_NAME) + logger.info(ToolCallEvent(tool_name="Tool1", call_id="123", arguments={"arg1": "value1"})) + + """ + self.kwargs: Dict[str, Any] = {} + self.kwargs["type"] = "ToolCall" + self.kwargs["tool_name"] = tool_name + self.kwargs["arguments"] = arguments + self.kwargs["result"] = result + try: + agent_id = MessageHandlerContext.agent_id() + except RuntimeError: + agent_id = None + self.kwargs["agent_id"] = None if agent_id is None else str(agent_id) + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageKind(Enum): + DIRECT = 1 + PUBLISH = 2 + RESPOND = 3 + + +class DeliveryStage(Enum): + SEND = 1 + DELIVER = 2 + + +class MessageEvent: + def __init__( + self, + *, + payload: str, + sender: AgentId | None, + receiver: AgentId | TopicId | None, + kind: MessageKind, + delivery_stage: DeliveryStage, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["sender"] = None if sender is None else str(sender) + self.kwargs["receiver"] = None if receiver is None else str(receiver) + self.kwargs["kind"] = str(kind) + self.kwargs["delivery_stage"] = str(delivery_stage) + self.kwargs["type"] = "Message" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageDroppedEvent: + def __init__( + self, + *, + payload: str, + sender: AgentId | None, + receiver: AgentId | TopicId | None, + kind: MessageKind, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["sender"] = None if sender is None else str(sender) + self.kwargs["receiver"] = None if receiver is None else str(receiver) + self.kwargs["kind"] = str(kind) + self.kwargs["type"] = "MessageDropped" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageHandlerExceptionEvent: + def __init__( + self, + *, + payload: str, + handling_agent: AgentId, + exception: BaseException, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["handling_agent"] = str(handling_agent) + self.kwargs["exception"] = str(exception) + self.kwargs["type"] = "MessageHandlerException" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class AgentConstructionExceptionEvent: + def __init__( + self, + *, + agent_id: AgentId, + exception: BaseException, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["agent_id"] = str(agent_id) + self.kwargs["exception"] = str(exception) + self.kwargs["type"] = "AgentConstructionException" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) diff --git a/agent_dhal/agentdhal_core/memory/__init__.py b/agent_dhal/agentdhal_core/memory/__init__.py new file mode 100644 index 0000000..69a20f2 --- /dev/null +++ b/agent_dhal/agentdhal_core/memory/__init__.py @@ -0,0 +1,11 @@ +from ._base_memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult +from ._list_memory import ListMemory + +__all__ = [ + "Memory", + "MemoryContent", + "MemoryQueryResult", + "UpdateContextResult", + "MemoryMimeType", + "ListMemory", +] diff --git a/agent_dhal/agentdhal_core/memory/_base_memory.py b/agent_dhal/agentdhal_core/memory/_base_memory.py new file mode 100644 index 0000000..c6f5563 --- /dev/null +++ b/agent_dhal/agentdhal_core/memory/_base_memory.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Union + +from pydantic import BaseModel, ConfigDict, field_serializer + +from .._cancellation_token import CancellationToken +from .._component_config import ComponentBase +from .._image import Image +from ..model_context import ChatCompletionContext + + +class MemoryMimeType(Enum): + """Supported MIME types for memory content.""" + + TEXT = "text/plain" + JSON = "application/json" + MARKDOWN = "text/markdown" + IMAGE = "image/*" + BINARY = "application/octet-stream" + + +ContentType = Union[str, bytes, Dict[str, Any], Image] + + +class MemoryContent(BaseModel): + """A memory content item.""" + + content: ContentType + """The content of the memory item. It can be a string, bytes, dict, or :class:`~agentdhal_core.Image`.""" + + mime_type: MemoryMimeType | str + """The MIME type of the memory content.""" + + metadata: Dict[str, Any] | None = None + """Metadata associated with the memory item.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_serializer("mime_type") + def serialize_mime_type(self, mime_type: MemoryMimeType | str) -> str: + """Serialize the MIME type to a string.""" + if isinstance(mime_type, MemoryMimeType): + return mime_type.value + return mime_type + + +class MemoryQueryResult(BaseModel): + """Result of a memory :meth:`~agentdhal_core.memory.Memory.query` operation.""" + + results: List[MemoryContent] + + +class UpdateContextResult(BaseModel): + """Result of a memory :meth:`~agentdhal_core.memory.Memory.update_context` operation.""" + + memories: MemoryQueryResult + + +class Memory(ABC, ComponentBase[BaseModel]): + """Protocol defining the interface for memory implementations. + + A memory is the storage for data that can be used to enrich or modify the model context. + + A memory implementation can use any storage mechanism, such as a list, a database, or a file system. + It can also use any retrieval mechanism, such as vector search or text search. + It is up to the implementation to decide how to store and retrieve data. + + It is also a memory implementation's responsibility to update the model context + with relevant memory content based on the current model context and querying the memory store. + + See :class:`~agentdhal_core.memory.ListMemory` for an example implementation. + """ + + component_type = "memory" + + @abstractmethod + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + """ + Update the provided model context using relevant memory content. + + Args: + model_context: The context to update. + + Returns: + UpdateContextResult containing relevant memories + """ + ... + + @abstractmethod + async def query( + self, + query: str | MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> MemoryQueryResult: + """ + Query the memory store and return relevant entries. + + Args: + query: Query content item + cancellation_token: Optional token to cancel operation + **kwargs: Additional implementation-specific parameters + + Returns: + MemoryQueryResult containing memory entries with relevance scores + """ + ... + + @abstractmethod + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + """ + Add a new content to memory. + + Args: + content: The memory content to add + cancellation_token: Optional token to cancel operation + """ + ... + + @abstractmethod + async def clear(self) -> None: + """Clear all entries from memory.""" + ... + + @abstractmethod + async def close(self) -> None: + """Clean up any resources used by the memory implementation.""" + ... diff --git a/agent_dhal/agentdhal_core/memory/_list_memory.py b/agent_dhal/agentdhal_core/memory/_list_memory.py new file mode 100644 index 0000000..cc7cdc7 --- /dev/null +++ b/agent_dhal/agentdhal_core/memory/_list_memory.py @@ -0,0 +1,172 @@ +from typing import Any, List + +from pydantic import BaseModel, Field +from typing_extensions import Self + +from .._cancellation_token import CancellationToken +from .._component_config import Component +from ..model_context import ChatCompletionContext +from ..models import SystemMessage +from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult + + +class ListMemoryConfig(BaseModel): + """Configuration for ListMemory component.""" + + name: str | None = None + """Optional identifier for this memory instance.""" + memory_contents: List[MemoryContent] = Field(default_factory=list) + """List of memory contents stored in this memory instance.""" + + +class ListMemory(Memory, Component[ListMemoryConfig]): + """Simple chronological list-based memory implementation. + + This memory implementation stores contents in a list and retrieves them in + chronological order. It has an `update_context` method that updates model contexts + by appending all stored memories. + + The memory content can be directly accessed and modified through the content property, + allowing external applications to manage memory contents directly. + + Example: + + .. code-block:: python + + import asyncio + from agentdhal_core.memory import ListMemory, MemoryContent + from agentdhal_core.model_context import BufferedChatCompletionContext + + + async def main() -> None: + # Initialize memory + memory = ListMemory(name="chat_history") + + # Add memory content + content = MemoryContent(content="User prefers formal language", mime_type="text/plain") + await memory.add(content) + + # Directly modify memory contents + memory.content = [MemoryContent(content="New preference", mime_type="text/plain")] + + # Create a model context + model_context = BufferedChatCompletionContext(buffer_size=10) + + # Update a model context with memory + await memory.update_context(model_context) + + # See the updated model context + print(await model_context.get_messages()) + + + asyncio.run(main()) + + Args: + name: Optional identifier for this memory instance + + """ + + component_type = "memory" + component_provider_override = "agentdhal_core.memory.ListMemory" + component_config_schema = ListMemoryConfig + + def __init__(self, name: str | None = None, memory_contents: List[MemoryContent] | None = None) -> None: + self._name = name or "default_list_memory" + self._contents: List[MemoryContent] = memory_contents if memory_contents is not None else [] + + @property + def name(self) -> str: + """Get the memory instance identifier. + + Returns: + str: Memory instance name + """ + return self._name + + @property + def content(self) -> List[MemoryContent]: + """Get the current memory contents. + + Returns: + List[MemoryContent]: List of stored memory contents + """ + return self._contents + + @content.setter + def content(self, value: List[MemoryContent]) -> None: + """Set the memory contents. + + Args: + value: New list of memory contents to store + """ + self._contents = value + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + """Update the model context by appending memory content. + + This method mutates the provided model_context by adding all memories as a + SystemMessage. + + Args: + model_context: The context to update. Will be mutated if memories exist. + + Returns: + UpdateContextResult containing the memories that were added to the context + """ + + if not self._contents: + return UpdateContextResult(memories=MemoryQueryResult(results=[])) + + memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)] + + if memory_strings: + memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n" + await model_context.add_message(SystemMessage(content=memory_context)) + + return UpdateContextResult(memories=MemoryQueryResult(results=self._contents)) + + async def query( + self, + query: str | MemoryContent = "", + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> MemoryQueryResult: + """Return all memories without any filtering. + + Args: + query: Ignored in this implementation + cancellation_token: Optional token to cancel operation + **kwargs: Additional parameters (ignored) + + Returns: + MemoryQueryResult containing all stored memories + """ + _ = query, cancellation_token, kwargs + return MemoryQueryResult(results=self._contents) + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + """Add new content to memory. + + Args: + content: Memory content to store + cancellation_token: Optional token to cancel operation + """ + self._contents.append(content) + + async def clear(self) -> None: + """Clear all memory content.""" + self._contents = [] + + async def close(self) -> None: + """Cleanup resources if needed.""" + pass + + @classmethod + def _from_config(cls, config: ListMemoryConfig) -> Self: + return cls(name=config.name, memory_contents=config.memory_contents) + + def _to_config(self) -> ListMemoryConfig: + return ListMemoryConfig(name=self.name, memory_contents=self._contents) diff --git a/agent_dhal/agentdhal_core/model_context/__init__.py b/agent_dhal/agentdhal_core/model_context/__init__.py new file mode 100644 index 0000000..b689861 --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/__init__.py @@ -0,0 +1,16 @@ +from ._buffered_chat_completion_context import BufferedChatCompletionContext +from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState +from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext +from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext +from ._unbounded_chat_completion_context import ( + UnboundedChatCompletionContext, +) + +__all__ = [ + "ChatCompletionContext", + "ChatCompletionContextState", + "UnboundedChatCompletionContext", + "BufferedChatCompletionContext", + "TokenLimitedChatCompletionContext", + "HeadAndTailChatCompletionContext", +] diff --git a/agent_dhal/agentdhal_core/model_context/_buffered_chat_completion_context.py b/agent_dhal/agentdhal_core/model_context/_buffered_chat_completion_context.py new file mode 100644 index 0000000..b6c0a5e --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/_buffered_chat_completion_context.py @@ -0,0 +1,50 @@ +from typing import List + +from pydantic import BaseModel +from typing_extensions import Self + +from .._component_config import Component +from ..models import FunctionExecutionResultMessage, LLMMessage +from ._chat_completion_context import ChatCompletionContext + + +class BufferedChatCompletionContextConfig(BaseModel): + buffer_size: int + initial_messages: List[LLMMessage] | None = None + + +class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedChatCompletionContextConfig]): + """A buffered chat completion context that keeps a view of the last n messages, + where n is the buffer size. The buffer size is set at initialization. + + Args: + buffer_size (int): The size of the buffer. + initial_messages (List[LLMMessage] | None): The initial messages. + """ + + component_config_schema = BufferedChatCompletionContextConfig + component_provider_override = "agentdhal_core.model_context.BufferedChatCompletionContext" + + def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None: + super().__init__(initial_messages) + if buffer_size <= 0: + raise ValueError("buffer_size must be greater than 0.") + self._buffer_size = buffer_size + + async def get_messages(self) -> List[LLMMessage]: + """Get at most `buffer_size` recent messages.""" + messages = self._messages[-self._buffer_size :] + # Handle the first message is a function call result message. + if messages and isinstance(messages[0], FunctionExecutionResultMessage): + # Remove the first message from the list. + messages = messages[1:] + return messages + + def _to_config(self) -> BufferedChatCompletionContextConfig: + return BufferedChatCompletionContextConfig( + buffer_size=self._buffer_size, initial_messages=self._initial_messages + ) + + @classmethod + def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self: + return cls(**config.model_dump()) diff --git a/agent_dhal/agentdhal_core/model_context/_chat_completion_context.py b/agent_dhal/agentdhal_core/model_context/_chat_completion_context.py new file mode 100644 index 0000000..1524a1e --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/_chat_completion_context.py @@ -0,0 +1,74 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Mapping + +from pydantic import BaseModel, Field + +from .._component_config import ComponentBase +from ..models import LLMMessage + + +class ChatCompletionContext(ABC, ComponentBase[BaseModel]): + """An abstract base class for defining the interface of a chat completion context. + A chat completion context lets agents store and retrieve LLM messages. + It can be implemented with different recall strategies. + + Args: + initial_messages (List[LLMMessage] | None): The initial messages. + + Example: + + To create a custom model context that filters out the thought field from AssistantMessage. + This is useful for reasoning models like DeepSeek R1, which produces + very long thought that is not needed for subsequent completions. + + .. code-block:: python + + from typing import List + + from agentdhal_core.model_context import UnboundedChatCompletionContext + from agentdhal_core.models import AssistantMessage, LLMMessage + + + class ReasoningModelContext(UnboundedChatCompletionContext): + \"\"\"A model context for reasoning models.\"\"\" + + async def get_messages(self) -> List[LLMMessage]: + messages = await super().get_messages() + # Filter out thought field from AssistantMessage. + messages_out: List[LLMMessage] = [] + for message in messages: + if isinstance(message, AssistantMessage): + message.thought = None + messages_out.append(message) + return messages_out + + """ + + component_type = "chat_completion_context" + + def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None: + self._messages: List[LLMMessage] = [] + if initial_messages is not None: + self._messages.extend(initial_messages) + self._initial_messages = initial_messages + + async def add_message(self, message: LLMMessage) -> None: + """Add a message to the context.""" + self._messages.append(message) + + @abstractmethod + async def get_messages(self) -> List[LLMMessage]: ... + + async def clear(self) -> None: + """Clear the context.""" + self._messages = [] + + async def save_state(self) -> Mapping[str, Any]: + return ChatCompletionContextState(messages=self._messages).model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + self._messages = ChatCompletionContextState.model_validate(state).messages + + +class ChatCompletionContextState(BaseModel): + messages: List[LLMMessage] = Field(default_factory=list) diff --git a/agent_dhal/agentdhal_core/model_context/_head_and_tail_chat_completion_context.py b/agent_dhal/agentdhal_core/model_context/_head_and_tail_chat_completion_context.py new file mode 100644 index 0000000..c9a10fd --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/_head_and_tail_chat_completion_context.py @@ -0,0 +1,76 @@ +from typing import List + +from pydantic import BaseModel +from typing_extensions import Self + +from .._component_config import Component +from .._types import FunctionCall +from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage +from ._chat_completion_context import ChatCompletionContext + + +class HeadAndTailChatCompletionContextConfig(BaseModel): + head_size: int + tail_size: int + initial_messages: List[LLMMessage] | None = None + + +class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndTailChatCompletionContextConfig]): + """A chat completion context that keeps a view of the first n and last m messages, + where n is the head size and m is the tail size. The head and tail sizes + are set at initialization. + + Args: + head_size (int): The size of the head. + tail_size (int): The size of the tail. + initial_messages (List[LLMMessage] | None): The initial messages. + """ + + component_config_schema = HeadAndTailChatCompletionContextConfig + component_provider_override = "agentdhal_core.model_context.HeadAndTailChatCompletionContext" + + def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None: + super().__init__(initial_messages) + if head_size <= 0: + raise ValueError("head_size must be greater than 0.") + if tail_size <= 0: + raise ValueError("tail_size must be greater than 0.") + self._head_size = head_size + self._tail_size = tail_size + + async def get_messages(self) -> List[LLMMessage]: + """Get at most `head_size` recent messages and `tail_size` oldest messages.""" + head_messages = self._messages[: self._head_size] + # Handle the last message is a function call message. + if ( + head_messages + and isinstance(head_messages[-1], AssistantMessage) + and isinstance(head_messages[-1].content, list) + and all(isinstance(item, FunctionCall) for item in head_messages[-1].content) + ): + # Remove the last message from the head. + head_messages = head_messages[:-1] + + tail_messages = self._messages[-self._tail_size :] + # Handle the first message is a function call result message. + if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage): + # Remove the first message from the tail. + tail_messages = tail_messages[1:] + + num_skipped = len(self._messages) - self._head_size - self._tail_size + if num_skipped <= 0: + # If there are not enough messages to fill the head and tail, + # return all messages. + return self._messages + + placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")] + return head_messages + placeholder_messages + tail_messages + + def _to_config(self) -> HeadAndTailChatCompletionContextConfig: + return HeadAndTailChatCompletionContextConfig( + head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages + ) + + @classmethod + def _from_config(cls, config: HeadAndTailChatCompletionContextConfig) -> Self: + return cls(head_size=config.head_size, tail_size=config.tail_size, initial_messages=config.initial_messages) diff --git a/agent_dhal/agentdhal_core/model_context/_token_limited_chat_completion_context.py b/agent_dhal/agentdhal_core/model_context/_token_limited_chat_completion_context.py new file mode 100644 index 0000000..1280ee7 --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/_token_limited_chat_completion_context.py @@ -0,0 +1,94 @@ +from typing import List + +from pydantic import BaseModel +from typing_extensions import Self + +from .._component_config import Component, ComponentModel +from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage +from ..tools import ToolSchema +from ._chat_completion_context import ChatCompletionContext + + +class TokenLimitedChatCompletionContextConfig(BaseModel): + model_client: ComponentModel + token_limit: int | None = None + tool_schema: List[ToolSchema] | None = None + initial_messages: List[LLMMessage] | None = None + + +class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]): + """(Experimental) A token based chat completion context maintains a view of the context up to a token limit. + + .. note:: + + Added in v0.4.10. This is an experimental component and may change in the future. + + Args: + model_client (ChatCompletionClient): The model client to use for token counting. + The model client must implement the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens` + and :meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` methods. + token_limit (int | None): The maximum number of tokens to keep in the context + using the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens` method. + If None, the context will be limited by the model client using the + :meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` method. + tools (List[ToolSchema] | None): A list of tool schema to use in the context. + initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context. + + """ + + component_config_schema = TokenLimitedChatCompletionContextConfig + component_provider_override = "agentdhal_core.model_context.TokenLimitedChatCompletionContext" + + def __init__( + self, + model_client: ChatCompletionClient, + *, + token_limit: int | None = None, + tool_schema: List[ToolSchema] | None = None, + initial_messages: List[LLMMessage] | None = None, + ) -> None: + super().__init__(initial_messages) + if token_limit is not None and token_limit <= 0: + raise ValueError("token_limit must be greater than 0.") + self._token_limit = token_limit + self._model_client = model_client + self._tool_schema = tool_schema or [] + + async def get_messages(self) -> List[LLMMessage]: + """Get at most `token_limit` tokens in recent messages. If the token limit is not + provided, then return as many messages as the remaining token allowed by the model client.""" + messages = list(self._messages) + if self._token_limit is None: + remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema) + while remaining_tokens < 0 and len(messages) > 0: + middle_index = len(messages) // 2 + messages.pop(middle_index) + remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema) + else: + token_count = self._model_client.count_tokens(messages, tools=self._tool_schema) + while token_count > self._token_limit and len(messages) > 0: + middle_index = len(messages) // 2 + messages.pop(middle_index) + token_count = self._model_client.count_tokens(messages, tools=self._tool_schema) + if messages and isinstance(messages[0], FunctionExecutionResultMessage): + # Handle the first message is a function call result message. + # Remove the first message from the list. + messages = messages[1:] + return messages + + def _to_config(self) -> TokenLimitedChatCompletionContextConfig: + return TokenLimitedChatCompletionContextConfig( + model_client=self._model_client.dump_component(), + token_limit=self._token_limit, + tool_schema=self._tool_schema, + initial_messages=self._initial_messages, + ) + + @classmethod + def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self: + return cls( + model_client=ChatCompletionClient.load_component(config.model_client), + token_limit=config.token_limit, + tool_schema=config.tool_schema, + initial_messages=config.initial_messages, + ) diff --git a/agent_dhal/agentdhal_core/model_context/_unbounded_chat_completion_context.py b/agent_dhal/agentdhal_core/model_context/_unbounded_chat_completion_context.py new file mode 100644 index 0000000..98f9552 --- /dev/null +++ b/agent_dhal/agentdhal_core/model_context/_unbounded_chat_completion_context.py @@ -0,0 +1,30 @@ +from typing import List + +from pydantic import BaseModel +from typing_extensions import Self + +from .._component_config import Component +from ..models import LLMMessage +from ._chat_completion_context import ChatCompletionContext + + +class UnboundedChatCompletionContextConfig(BaseModel): + initial_messages: List[LLMMessage] | None = None + + +class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]): + """An unbounded chat completion context that keeps a view of the all the messages.""" + + component_config_schema = UnboundedChatCompletionContextConfig + component_provider_override = "agentdhal_core.model_context.UnboundedChatCompletionContext" + + async def get_messages(self) -> List[LLMMessage]: + """Get at most `buffer_size` recent messages.""" + return self._messages + + def _to_config(self) -> UnboundedChatCompletionContextConfig: + return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages) + + @classmethod + def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self: + return cls(initial_messages=config.initial_messages) diff --git a/agent_dhal/agentdhal_core/py.typed b/agent_dhal/agentdhal_core/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_core/tool_agent/__init__.py b/agent_dhal/agentdhal_core/tool_agent/__init__.py new file mode 100644 index 0000000..e072efd --- /dev/null +++ b/agent_dhal/agentdhal_core/tool_agent/__init__.py @@ -0,0 +1,17 @@ +from ._caller_loop import tool_agent_caller_loop +from ._tool_agent import ( + InvalidToolArgumentsException, + ToolAgent, + ToolException, + ToolExecutionException, + ToolNotFoundException, +) + +__all__ = [ + "ToolAgent", + "ToolException", + "ToolNotFoundException", + "InvalidToolArgumentsException", + "ToolExecutionException", + "tool_agent_caller_loop", +] diff --git a/agent_dhal/agentdhal_core/tool_agent/_caller_loop.py b/agent_dhal/agentdhal_core/tool_agent/_caller_loop.py new file mode 100644 index 0000000..e5d64c3 --- /dev/null +++ b/agent_dhal/agentdhal_core/tool_agent/_caller_loop.py @@ -0,0 +1,80 @@ +import asyncio +from typing import List + +from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall +from ..models import ( + AssistantMessage, + ChatCompletionClient, + FunctionExecutionResult, + FunctionExecutionResultMessage, + LLMMessage, +) +from ..tools import Tool, ToolSchema +from ._tool_agent import ToolException + + +async def tool_agent_caller_loop( + caller: BaseAgent | AgentRuntime, + tool_agent_id: AgentId, + model_client: ChatCompletionClient, + input_messages: List[LLMMessage], + tool_schema: List[ToolSchema] | List[Tool], + cancellation_token: CancellationToken | None = None, + caller_source: str = "assistant", +) -> List[LLMMessage]: + """Start a caller loop for a tool agent. This function sends messages to the tool agent + and the model client in an alternating fashion until the model client stops generating tool calls. + + Args: + tool_agent_id (AgentId): The Agent ID of the tool agent. + input_messages (List[LLMMessage]): The list of input messages. + model_client (ChatCompletionClient): The model client to use for the model API. + tool_schema (List[Tool | ToolSchema]): The list of tools that the model can use. + + Returns: + List[LLMMessage]: The list of output messages created in the caller loop. + """ + + generated_messages: List[LLMMessage] = [] + + # Get a response from the model. + response = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token) + # Add the response to the generated messages. + generated_messages.append(AssistantMessage(content=response.content, source=caller_source)) + + # Keep iterating until the model stops generating tool calls. + while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content): + # Execute functions called by the model by sending messages to tool agent. + results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( + *[ + caller.send_message( + message=call, + recipient=tool_agent_id, + cancellation_token=cancellation_token, + ) + for call in response.content + ], + return_exceptions=True, + ) + # Combine the results into a single response and handle exceptions. + function_results: List[FunctionExecutionResult] = [] + for result in results: + if isinstance(result, FunctionExecutionResult): + function_results.append(result) + elif isinstance(result, ToolException): + function_results.append( + FunctionExecutionResult( + content=f"Error: {result}", call_id=result.call_id, is_error=True, name=result.name + ) + ) + elif isinstance(result, BaseException): + raise result # Unexpected exception. + generated_messages.append(FunctionExecutionResultMessage(content=function_results)) + # Query the model again with the new response. + response = await model_client.create( + input_messages + generated_messages, tools=tool_schema, cancellation_token=cancellation_token + ) + generated_messages.append(AssistantMessage(content=response.content, source=caller_source)) + + # Return the generated messages. + return generated_messages diff --git a/agent_dhal/agentdhal_core/tool_agent/_tool_agent.py b/agent_dhal/agentdhal_core/tool_agent/_tool_agent.py new file mode 100644 index 0000000..2ddb8dc --- /dev/null +++ b/agent_dhal/agentdhal_core/tool_agent/_tool_agent.py @@ -0,0 +1,96 @@ +import json +from dataclasses import dataclass +from typing import List + +from .. import FunctionCall, MessageContext, RoutedAgent, message_handler +from ..models import FunctionExecutionResult +from ..tools import Tool + +__all__ = [ + "ToolAgent", + "ToolException", + "ToolNotFoundException", + "InvalidToolArgumentsException", + "ToolExecutionException", +] + + +@dataclass +class ToolException(BaseException): + call_id: str + content: str + name: str + + +@dataclass +class ToolNotFoundException(ToolException): + pass + + +@dataclass +class InvalidToolArgumentsException(ToolException): + pass + + +@dataclass +class ToolExecutionException(ToolException): + pass + + +class ToolAgent(RoutedAgent): + """A tool agent accepts direct messages of the type `FunctionCall`, + executes the requested tool with the provided arguments, and returns the + result as `FunctionExecutionResult` messages. + + Args: + description (str): The description of the agent. + tools (List[Tool]): The list of tools that the agent can execute. + """ + + def __init__( + self, + description: str, + tools: List[Tool], + ) -> None: + super().__init__(description) + self._tools = tools + + @property + def tools(self) -> List[Tool]: + return self._tools + + @message_handler + async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) -> FunctionExecutionResult: + """Handles a `FunctionCall` message by executing the requested tool with the provided arguments. + + Args: + message (FunctionCall): The function call message. + cancellation_token (CancellationToken): The cancellation token. + + Returns: + FunctionExecutionResult: The result of the function execution. + + Raises: + ToolNotFoundException: If the tool is not found. + InvalidToolArgumentsException: If the tool arguments are invalid. + ToolExecutionException: If the tool execution fails. + """ + tool = next((tool for tool in self._tools if tool.name == message.name), None) + if tool is None: + raise ToolNotFoundException( + call_id=message.id, content=f"Error: Tool not found: {message.name}", name=message.name + ) + else: + try: + arguments = json.loads(message.arguments) + result = await tool.run_json( + args=arguments, cancellation_token=ctx.cancellation_token, call_id=message.id + ) + result_as_str = tool.return_value_as_string(result) + except json.JSONDecodeError as e: + raise InvalidToolArgumentsException( + call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}", name=message.name + ) from e + except Exception as e: + raise ToolExecutionException(call_id=message.id, content=f"Error: {e}", name=message.name) from e + return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False, name=message.name) diff --git a/agent_dhal/agentdhal_core/tools/__init__.py b/agent_dhal/agentdhal_core/tools/__init__.py new file mode 100644 index 0000000..aee634e --- /dev/null +++ b/agent_dhal/agentdhal_core/tools/__init__.py @@ -0,0 +1,31 @@ +from ._base import ( + BaseStreamTool, + BaseTool, + BaseToolWithState, + ParametersSchema, + StreamTool, + Tool, + ToolOverride, + ToolSchema, +) +from ._function_tool import FunctionTool +from ._static_workbench import StaticStreamWorkbench, StaticWorkbench +from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench + +__all__ = [ + "Tool", + "StreamTool", + "ToolSchema", + "ParametersSchema", + "BaseTool", + "BaseToolWithState", + "BaseStreamTool", + "FunctionTool", + "Workbench", + "ToolResult", + "TextResultContent", + "ImageResultContent", + "StaticWorkbench", + "StaticStreamWorkbench", + "ToolOverride", +] diff --git a/agent_dhal/agentdhal_core/tools/_base.py b/agent_dhal/agentdhal_core/tools/_base.py new file mode 100644 index 0000000..d2ea76e --- /dev/null +++ b/agent_dhal/agentdhal_core/tools/_base.py @@ -0,0 +1,294 @@ +import json +import logging +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import ( + Any, + AsyncGenerator, + Dict, + Generic, + Mapping, + Optional, + Protocol, + Type, + TypeVar, + cast, + runtime_checkable, +) + +import jsonref +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict + +from .. import EVENT_LOGGER_NAME, CancellationToken +from .._component_config import ComponentBase +from .._function_utils import normalize_annotated_type +from .._telemetry import trace_tool_span +from ..logging import ToolCallEvent + +T = TypeVar("T", bound=BaseModel, contravariant=True) + +logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class ParametersSchema(TypedDict): + type: str + properties: Dict[str, Any] + required: NotRequired[Sequence[str]] + additionalProperties: NotRequired[bool] + + +class ToolSchema(TypedDict): + parameters: NotRequired[ParametersSchema] + name: str + description: NotRequired[str] + strict: NotRequired[bool] + + +class ToolOverride(BaseModel): + """Override configuration for a tool's name and/or description.""" + + name: Optional[str] = None + description: Optional[str] = None + + +@runtime_checkable +class Tool(Protocol): + @property + def name(self) -> str: ... + + @property + def description(self) -> str: ... + + @property + def schema(self) -> ToolSchema: ... + + def args_type(self) -> Type[BaseModel]: ... + + def return_type(self) -> Type[Any]: ... + + def state_type(self) -> Type[BaseModel] | None: ... + + def return_value_as_string(self, value: Any) -> str: ... + + async def run_json( + self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None + ) -> Any: ... + + async def save_state_json(self) -> Mapping[str, Any]: ... + + async def load_state_json(self, state: Mapping[str, Any]) -> None: ... + + +@runtime_checkable +class StreamTool(Tool, Protocol): + def run_json_stream( + self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None + ) -> AsyncGenerator[Any, None]: ... + + +ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True) +ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True) +StateT = TypeVar("StateT", bound=BaseModel) +StreamT = TypeVar("StreamT", bound=BaseModel, covariant=True) + + +class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]): + component_type = "tool" + + def __init__( + self, + args_type: Type[ArgsT], + return_type: Type[ReturnT], + name: str, + description: str, + strict: bool = False, + ) -> None: + self._args_type = args_type + # Normalize Annotated to the base type. + self._return_type = normalize_annotated_type(return_type) + self._name = name + self._description = description + self._strict = strict + + @property + def schema(self) -> ToolSchema: + model_schema: Dict[str, Any] = self._args_type.model_json_schema() + + if "$defs" in model_schema: + model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False)) # type: ignore + del model_schema["$defs"] + + parameters = ParametersSchema( + type="object", + properties=model_schema["properties"], + required=model_schema.get("required", []), + additionalProperties=model_schema.get("additionalProperties", False), + ) + + # If strict is enabled, the tool schema should list all properties as required. + assert "required" in parameters + if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()): + raise ValueError( + "Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode." + ) + + assert "additionalProperties" in parameters + if self._strict and parameters["additionalProperties"]: + raise ValueError( + "Strict mode is enabled but additional argument is also enabled. This is not allowed in strict mode." + ) + + tool_schema = ToolSchema( + name=self._name, + description=self._description, + parameters=parameters, + strict=self._strict, + ) + return tool_schema + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + def args_type(self) -> Type[BaseModel]: + return self._args_type + + def return_type(self) -> Type[Any]: + return self._return_type + + def state_type(self) -> Type[BaseModel] | None: + return None + + def return_value_as_string(self, value: Any) -> str: + if isinstance(value, BaseModel): + dumped = value.model_dump() + if isinstance(dumped, dict): + return json.dumps(dumped) + return str(dumped) + + return str(value) + + @abstractmethod + async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ... + + async def run_json( + self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None + ) -> Any: + """Run the tool with the provided arguments in a dictionary. + + Args: + args (Mapping[str, Any]): The arguments to pass to the tool. + cancellation_token (CancellationToken): A token to cancel the operation if needed. + call_id (str | None): An optional identifier for the tool call, used for tracing. + + Returns: + Any: The return value of the tool's run method. + """ + with trace_tool_span( + tool_name=self._name, + tool_description=self._description, + tool_call_id=call_id, + ): + # Execute the tool's run method + return_value = await self.run(self._args_type.model_validate(args), cancellation_token) + + # Log the tool call event + event = ToolCallEvent( + tool_name=self.name, + arguments=dict(args), # Using the raw args passed to run_json + result=self.return_value_as_string(return_value), + ) + logger.info(event) + + return return_value + + async def save_state_json(self) -> Mapping[str, Any]: + return {} + + async def load_state_json(self, state: Mapping[str, Any]) -> None: + pass + + +class BaseStreamTool( + BaseTool[ArgsT, ReturnT], StreamTool, ABC, Generic[ArgsT, StreamT, ReturnT], ComponentBase[BaseModel] +): + component_type = "tool" + + @abstractmethod + def run_stream(self, args: ArgsT, cancellation_token: CancellationToken) -> AsyncGenerator[StreamT | ReturnT, None]: + """Run the tool with the provided arguments and return a stream of data and end with the final return value.""" + ... + + async def run_json_stream( + self, + args: Mapping[str, Any], + cancellation_token: CancellationToken, + call_id: str | None = None, + ) -> AsyncGenerator[StreamT | ReturnT, None]: + """Run the tool with the provided arguments in a dictionary and return a stream of data + from the tool's :meth:`run_stream` method and end with the final return value. + + Args: + args (Mapping[str, Any]): The arguments to pass to the tool. + cancellation_token (CancellationToken): A token to cancel the operation if needed. + call_id (str | None): An optional identifier for the tool call, used for tracing. + + Returns: + AsyncGenerator[StreamT | ReturnT, None]: A generator yielding results from the tool's :meth:`run_stream` method. + """ + return_value: ReturnT | StreamT | None = None + with trace_tool_span( + tool_name=self._name, + tool_description=self._description, + tool_call_id=call_id, + ): + # Execute the tool's run_stream method + async for result in self.run_stream(self._args_type.model_validate(args), cancellation_token): + return_value = result + yield result + + assert return_value is not None, "The tool must yield a final return value at the end of the stream." + if not isinstance(return_value, self._return_type): + raise TypeError( + f"Expected return value of type {self._return_type.__name__}, but got {type(return_value).__name__}" + ) + + # Log the tool call event + event = ToolCallEvent( + tool_name=self.name, + arguments=dict(args), # Using the raw args passed to run_json + result=self.return_value_as_string(return_value), + ) + logger.info(event) + + +class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]): + def __init__( + self, + args_type: Type[ArgsT], + return_type: Type[ReturnT], + state_type: Type[StateT], + name: str, + description: str, + ) -> None: + super().__init__(args_type, return_type, name, description) + self._state_type = state_type + + component_type = "tool" + + @abstractmethod + def save_state(self) -> StateT: ... + + @abstractmethod + def load_state(self, state: StateT) -> None: ... + + async def save_state_json(self) -> Mapping[str, Any]: + return self.save_state().model_dump() + + async def load_state_json(self, state: Mapping[str, Any]) -> None: + self.load_state(self._state_type.model_validate(state)) diff --git a/agent_dhal/agentdhal_core/tools/_function_tool.py b/agent_dhal/agentdhal_core/tools/_function_tool.py new file mode 100644 index 0000000..5f88a9a --- /dev/null +++ b/agent_dhal/agentdhal_core/tools/_function_tool.py @@ -0,0 +1,181 @@ +import asyncio +import functools +import warnings +from textwrap import dedent +from typing import Any, Callable, Sequence + +from pydantic import BaseModel +from typing_extensions import Self + +from .. import CancellationToken +from .._component_config import Component +from .._function_utils import ( + args_base_model_from_signature, + get_typed_signature, +) +from ..code_executor._func_with_reqs import Import, import_to_str, to_code +from ._base import BaseTool + + +class FunctionToolConfig(BaseModel): + """Configuration for a function tool.""" + + source_code: str + name: str + description: str + global_imports: Sequence[Import] + has_cancellation_support: bool + + +class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]): + """ + Create custom tools by wrapping standard Python functions. + + `FunctionTool` offers an interface for executing Python functions either asynchronously or synchronously. + Each function must include type annotations for all parameters and its return type. These annotations + enable `FunctionTool` to generate a schema necessary for input validation, serialization, and for informing + the LLM about expected parameters. When the LLM prepares a function call, it leverages this schema to + generate arguments that align with the function's specifications. + + .. note:: + + It is the user's responsibility to verify that the tool's output type matches the expected type. + + Args: + func (Callable[..., ReturnT | Awaitable[ReturnT]]): The function to wrap and expose as a tool. + description (str): A description to inform the model of the function's purpose, specifying what + it does and the context in which it should be called. + name (str, optional): An optional custom name for the tool. Defaults to + the function's original name if not provided. + strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly + defined in the function signature, and no default values will be allowed. Defaults to False. + This is required to be set to True when used with models in structured output mode. + + Example: + + .. code-block:: python + + import random + from agentdhal_core import CancellationToken + from agentdhal_core.tools import FunctionTool + from typing_extensions import Annotated + import asyncio + + + async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float: + # Simulates a stock price retrieval by returning a random float within a specified range. + return random.uniform(10, 200) + + + async def example(): + # Initialize a FunctionTool instance for retrieving stock prices. + stock_price_tool = FunctionTool(get_stock_price, description="Fetch the stock price for a given ticker.") + + # Execute the tool with cancellation support. + cancellation_token = CancellationToken() + result = await stock_price_tool.run_json({"ticker": "AAPL", "date": "2021/01/01"}, cancellation_token) + + # Output the result as a formatted string. + print(stock_price_tool.return_value_as_string(result)) + + + asyncio.run(example()) + """ + + component_provider_override = "agentdhal_core.tools.FunctionTool" + component_config_schema = FunctionToolConfig + + def __init__( + self, + func: Callable[..., Any], + description: str, + name: str | None = None, + global_imports: Sequence[Import] = [], + strict: bool = False, + ) -> None: + self._func = func + self._global_imports = global_imports + self._signature = get_typed_signature(func) + func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__ + args_model = args_base_model_from_signature(func_name + "args", self._signature) + self._has_cancellation_support = "cancellation_token" in self._signature.parameters + return_type = self._signature.return_annotation + super().__init__(args_model, return_type, func_name, description, strict) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + kwargs = {} + + for name in self._signature.parameters.keys(): + if hasattr(args, name): + kwargs[name] = getattr(args, name) + + if asyncio.iscoroutinefunction(self._func): + if self._has_cancellation_support: + result = await self._func(**kwargs, cancellation_token=cancellation_token) + else: + result = await self._func(**kwargs) + else: + if self._has_cancellation_support: + result = await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + self._func, + **kwargs, + cancellation_token=cancellation_token, + ), + ) + else: + future = asyncio.get_event_loop().run_in_executor(None, functools.partial(self._func, **kwargs)) + cancellation_token.link_future(future) + result = await future + + return result + + def _to_config(self) -> FunctionToolConfig: + return FunctionToolConfig( + source_code=dedent(to_code(self._func)), + global_imports=self._global_imports, + name=self.name, + description=self.description, + has_cancellation_support=self._has_cancellation_support, + ) + + @classmethod + def _from_config(cls, config: FunctionToolConfig) -> Self: + warnings.warn( + "\n⚠️ SECURITY WARNING ⚠️\n" + "Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n" + "Only load configs from TRUSTED sources to prevent arbitrary code execution.", + UserWarning, + stacklevel=2, + ) + + exec_globals: dict[str, Any] = {} + + # Execute imports first + for import_stmt in config.global_imports: + import_code = import_to_str(import_stmt) + try: + exec(import_code, exec_globals) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import {import_code}: Module not found. Please ensure the module is installed." + ) from e + except ImportError as e: + raise ImportError(f"Failed to import {import_code}: {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e + + # Execute function code + try: + exec(config.source_code, exec_globals) + func_name = config.source_code.split("def ")[1].split("(")[0] + except Exception as e: + raise ValueError(f"Could not compile and load function: {e}") from e + + # Get function and verify it's callable + func: Callable[..., Any] = exec_globals[func_name] + if not callable(func): + raise TypeError(f"Expected function but got {type(func)}") + + return cls(func, name=config.name, description=config.description, global_imports=config.global_imports) diff --git a/agent_dhal/agentdhal_core/tools/_static_workbench.py b/agent_dhal/agentdhal_core/tools/_static_workbench.py new file mode 100644 index 0000000..e44a211 --- /dev/null +++ b/agent_dhal/agentdhal_core/tools/_static_workbench.py @@ -0,0 +1,225 @@ +import asyncio +import builtins +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Self + +from .._cancellation_token import CancellationToken +from .._component_config import Component, ComponentModel +from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema +from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench + + +class StaticWorkbenchConfig(BaseModel): + tools: List[ComponentModel] = [] + tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict) + + +class StateicWorkbenchState(BaseModel): + type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState" + tools: Dict[str, Mapping[str, Any]] = {} + + +class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): + """ + A workbench that provides a static set of tools that do not change after + each tool execution. + + Args: + tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench. + The tools should be subclasses of :class:`~agentdhal_core.tools.BaseTool`. + tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool + names to override configurations for name and/or description. This allows + customizing how tools appear to consumers while maintaining the underlying + tool functionality. + """ + + component_provider_override = "agentdhal_core.tools.StaticWorkbench" + component_config_schema = StaticWorkbenchConfig + + def __init__( + self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None + ) -> None: + self._tools = tools + self._tool_overrides = tool_overrides or {} + + # Build reverse mapping from override names to original names for call_tool + self._override_name_to_original: Dict[str, str] = {} + existing_tool_names = {tool.name for tool in self._tools} + + for original_name, override in self._tool_overrides.items(): + if override.name and override.name != original_name: + # Check for conflicts with existing tool names + if override.name in existing_tool_names and override.name != original_name: + raise ValueError( + f"Tool override name '{override.name}' conflicts with existing tool name. " + f"Override names must not conflict with any tool names." + ) + # Check for conflicts with other override names + if override.name in self._override_name_to_original: + existing_original = self._override_name_to_original[override.name] + raise ValueError( + f"Tool override name '{override.name}' is used by multiple tools: " + f"'{existing_original}' and '{original_name}'. Override names must be unique." + ) + self._override_name_to_original[override.name] = original_name + + async def list_tools(self) -> List[ToolSchema]: + result_schemas: List[ToolSchema] = [] + for tool in self._tools: + original_schema = tool.schema + + # Apply overrides if they exist for this tool + if tool.name in self._tool_overrides: + override = self._tool_overrides[tool.name] + # Create a new ToolSchema with overrides applied + schema: ToolSchema = { + "name": override.name if override.name is not None else original_schema["name"], + "description": override.description + if override.description is not None + else original_schema.get("description", ""), + } + # Copy optional fields + if "parameters" in original_schema: + schema["parameters"] = original_schema["parameters"] + if "strict" in original_schema: + schema["strict"] = original_schema["strict"] + else: + schema = original_schema + + result_schemas.append(schema) + return result_schemas + + async def call_tool( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> ToolResult: + # Check if the name is an override name and map it back to the original + original_name = self._override_name_to_original.get(name, name) + + tool = next((tool for tool in self._tools if tool.name == original_name), None) + if tool is None: + return ToolResult( + name=name, # Return the requested name (which might be overridden) + result=[TextResultContent(content=f"Tool {name} not found.")], + is_error=True, + ) + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + try: + result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id)) + cancellation_token.link_future(result_future) + actual_tool_output = await result_future + is_error = False + result_str = tool.return_value_as_string(actual_tool_output) + except Exception as e: + result_str = self._format_errors(e) + is_error = True + return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error) + + async def start(self) -> None: + return None + + async def stop(self) -> None: + return None + + async def reset(self) -> None: + return None + + async def save_state(self) -> Mapping[str, Any]: + tool_states = StateicWorkbenchState() + for tool in self._tools: + tool_states.tools[tool.name] = await tool.save_state_json() + return tool_states.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + parsed_state = StateicWorkbenchState.model_validate(state) + for tool in self._tools: + if tool.name in parsed_state.tools: + await tool.load_state_json(parsed_state.tools[tool.name]) + + def _to_config(self) -> StaticWorkbenchConfig: + return StaticWorkbenchConfig( + tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides + ) + + @classmethod + def _from_config(cls, config: StaticWorkbenchConfig) -> Self: + return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides) + + def _format_errors(self, error: Exception) -> str: + """Recursively format errors into a string.""" + + error_message = "" + if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup): + # ExceptionGroup is available in Python 3.11+. + # TODO: how to make this compatible with Python 3.10? + for sub_exception in error.exceptions: # type: ignore + error_message += self._format_errors(sub_exception) # type: ignore + else: + error_message += f"{str(error)}\n" + return error_message.strip() + + +class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench): + """ + A workbench that provides a static set of tools that do not change after + each tool execution, and supports streaming results. + """ + + component_provider_override = "agentdhal_core.tools.StaticStreamWorkbench" + + async def call_tool_stream( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> AsyncGenerator[Any | ToolResult, None]: + tool = next((tool for tool in self._tools if tool.name == name), None) + if tool is None: + yield ToolResult( + name=name, + result=[TextResultContent(content=f"Tool {name} not found.")], + is_error=True, + ) + return + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + try: + actual_tool_output: Any | None = None + if isinstance(tool, StreamTool): + previous_result: Any | None = None + try: + async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id): + if previous_result is not None: + yield previous_result + previous_result = result + actual_tool_output = previous_result + except Exception as e: + # If there was a previous result before the exception, yield it first + if previous_result is not None: + yield previous_result + # Then yield the error result + result_str = self._format_errors(e) + yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True) + return + else: + # If the tool is not a stream tool, we run it normally and yield the result + result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id)) + cancellation_token.link_future(result_future) + actual_tool_output = await result_future + is_error = False + result_str = tool.return_value_as_string(actual_tool_output) + except Exception as e: + result_str = self._format_errors(e) + is_error = True + yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error) diff --git a/agent_dhal/agentdhal_core/tools/_workbench.py b/agent_dhal/agentdhal_core/tools/_workbench.py new file mode 100644 index 0000000..faedb70 --- /dev/null +++ b/agent_dhal/agentdhal_core/tools/_workbench.py @@ -0,0 +1,216 @@ +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Type + +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Self + +from .._cancellation_token import CancellationToken +from .._component_config import ComponentBase +from .._image import Image +from ._base import ToolSchema + + +class TextResultContent(BaseModel): + """ + Text result content of a tool execution. + """ + + type: Literal["TextResultContent"] = "TextResultContent" + + content: str + """The text content of the result.""" + + +class ImageResultContent(BaseModel): + """ + Image result content of a tool execution. + """ + + type: Literal["ImageResultContent"] = "ImageResultContent" + + content: Image + """The image content of the result.""" + + +ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")] + + +class ToolResult(BaseModel): + """ + A result of a tool execution by a workbench. + """ + + type: Literal["ToolResult"] = "ToolResult" + + name: str + """The name of the tool that was executed.""" + + result: List[ResultContent] + """The result of the tool execution.""" + + is_error: bool = False + """Whether the tool execution resulted in an error.""" + + def to_text(self, replace_image: str | None = None) -> str: + """ + Convert the result to a text string. + + Args: + replace_image (str | None): The string to replace the image content with. + If None, the image content will be included in the text as base64 string. + + Returns: + str: The text representation of the result. + """ + parts: List[str] = [] + for content in self.result: + if isinstance(content, TextResultContent): + parts.append(content.content) + elif isinstance(content, ImageResultContent): + if replace_image is not None: + parts.append(replace_image) + else: + parts.append(f"[Image: {content.content.to_base64()}]") + return "\n".join(parts) + + +class Workbench(ABC, ComponentBase[BaseModel]): + """ + A workbench is a component that provides a set of tools that may share + resources and state. + + A workbench is responsible for managing the lifecycle of the tools and + providing a single interface to call them. The tools provided by the workbench + may be dynamic and their availabilities may change after each tool execution. + + A workbench can be started by calling the :meth:`~agentdhal_core.tools.Workbench.start` method + and stopped by calling the :meth:`~agentdhal_core.tools.Workbench.stop` method. + It can also be used as an asynchronous context manager, which will automatically + start and stop the workbench when entering and exiting the context. + """ + + component_type = "workbench" + + @abstractmethod + async def list_tools(self) -> List[ToolSchema]: + """ + List the currently available tools in the workbench as :class:`ToolSchema` + objects. + + The list of tools may be dynamic, and their content may change after + tool execution. + """ + ... + + @abstractmethod + async def call_tool( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> ToolResult: + """ + Call a tool in the workbench. + + Args: + name (str): The name of the tool to call. + arguments (Mapping[str, Any] | None): The arguments to pass to the tool. + If None, the tool will be called with no arguments. + cancellation_token (CancellationToken | None): An optional cancellation token + to cancel the tool execution. + call_id (str | None): An optional identifier for the tool call, used for tracing. + Returns: + ToolResult: The result of the tool execution. + """ + ... + + @abstractmethod + async def start(self) -> None: + """ + Start the workbench and initialize any resources. + + This method should be called before using the workbench. + """ + ... + + @abstractmethod + async def stop(self) -> None: + """ + Stop the workbench and release any resources. + + This method should be called when the workbench is no longer needed. + """ + ... + + @abstractmethod + async def reset(self) -> None: + """ + Reset the workbench to its initialized, started state. + """ + ... + + @abstractmethod + async def save_state(self) -> Mapping[str, Any]: + """ + Save the state of the workbench. + + This method should be called to persist the state of the workbench. + """ + ... + + @abstractmethod + async def load_state(self, state: Mapping[str, Any]) -> None: + """ + Load the state of the workbench. + + Args: + state (Mapping[str, Any]): The state to load into the workbench. + """ + ... + + async def __aenter__(self) -> Self: + """ + Enter the workbench context manager. + + This method is called when the workbench is used in a `with` statement. + It calls the :meth:`~agentdhal_core.tools.WorkBench.start` method to start the workbench. + """ + await self.start() + return self + + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + """ + Exit the workbench context manager. + This method is called when the workbench is used in a `with` statement. + It calls the :meth:`~agentdhal_core.tools.WorkBench.stop` method to stop the workbench. + """ + await self.stop() + + +class StreamWorkbench(Workbench, ABC): + """A workbench that supports streaming results from tool calls.""" + + @abstractmethod + def call_tool_stream( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> AsyncGenerator[Any | ToolResult, None]: + """ + Call a tool in the workbench and return a stream of results. + + Args: + name (str): The name of the tool to call. + arguments (Mapping[str, Any] | None): The arguments to pass to the tool + If None, the tool will be called with no arguments. + cancellation_token (CancellationToken | None): An optional cancellation token + to cancel the tool execution. + call_id (str | None): An optional identifier for the tool call, used for tracing. + """ + ... diff --git a/agent_dhal/agentdhal_core/utils/__init__.py b/agent_dhal/agentdhal_core/utils/__init__.py new file mode 100644 index 0000000..e46b756 --- /dev/null +++ b/agent_dhal/agentdhal_core/utils/__init__.py @@ -0,0 +1,4 @@ +from ._json_to_pydantic import schema_to_pydantic_model +from ._load_json import extract_json_from_str + +__all__ = ["schema_to_pydantic_model", "extract_json_from_str"] diff --git a/agent_dhal/agentdhal_core/utils/_json_to_pydantic.py b/agent_dhal/agentdhal_core/utils/_json_to_pydantic.py new file mode 100644 index 0000000..5a401af --- /dev/null +++ b/agent_dhal/agentdhal_core/utils/_json_to_pydantic.py @@ -0,0 +1,567 @@ +import datetime +from ipaddress import IPv4Address, IPv6Address +from typing import Annotated, Any, Dict, ForwardRef, List, Literal, Optional, Type, Union, cast + +from pydantic import ( + UUID1, + UUID3, + UUID4, + UUID5, + AnyUrl, + BaseModel, + EmailStr, + Field, + Json, + conbytes, + confloat, + conint, + conlist, + constr, + create_model, +) +from pydantic.fields import FieldInfo + + +class SchemaConversionError(Exception): + """Base class for schema conversion exceptions.""" + + pass + + +class ReferenceNotFoundError(SchemaConversionError): + """Raised when a $ref cannot be resolved.""" + + pass + + +class FormatNotSupportedError(SchemaConversionError): + """Raised when a format is not supported.""" + + pass + + +class UnsupportedKeywordError(SchemaConversionError): + """Raised when an unsupported JSON Schema keyword is encountered.""" + + pass + + +TYPE_MAPPING: Dict[str, Type[Any]] = { + "string": str, + "integer": int, + "boolean": bool, + "number": float, + "array": List, + "object": dict, + "null": type(None), +} + +FORMAT_MAPPING: Dict[str, Any] = { + "uuid": UUID4, + "uuid1": UUID1, + "uuid2": UUID4, + "uuid3": UUID3, + "uuid4": UUID4, + "uuid5": UUID5, + "email": EmailStr, + "uri": AnyUrl, + "hostname": constr(strict=True), + "ipv4": IPv4Address, + "ipv6": IPv6Address, + "ipv4-network": IPv4Address, + "ipv6-network": IPv6Address, + "date-time": datetime.datetime, + "date": datetime.date, + "time": datetime.time, + "duration": datetime.timedelta, + "int32": conint(strict=True, ge=-(2**31), le=2**31 - 1), + "int64": conint(strict=True, ge=-(2**63), le=2**63 - 1), + "float": confloat(strict=True), + "double": float, + "decimal": float, + "byte": conbytes(strict=True), + "binary": conbytes(strict=True), + "password": str, + "path": str, + "json": Json, +} + + +def _make_field( + default: Any, + *, + title: Optional[str] = None, + description: Optional[str] = None, +) -> Any: + """Construct a Pydantic Field with proper typing.""" + field_kwargs: Dict[str, Any] = {} + if title is not None: + field_kwargs["title"] = title + if description is not None: + field_kwargs["description"] = description + return Field(default, **field_kwargs) + + +class _JSONSchemaToPydantic: + def __init__(self) -> None: + self._model_cache: Dict[str, Optional[Union[Type[BaseModel], ForwardRef]]] = {} + + def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]: + ref_key = ref.split("/")[-1] + definitions = cast(dict[str, dict[str, Any]], schema.get("$defs", {})) + + if ref_key not in definitions: + raise ReferenceNotFoundError( + f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}" + ) + + return definitions[ref_key] + + def get_ref(self, ref_name: str) -> Any: + if ref_name not in self._model_cache: + raise ReferenceNotFoundError( + f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}" + ) + + if self._model_cache[ref_name] is None: + return ForwardRef(ref_name) + + return self._model_cache[ref_name] + + def _process_definitions(self, root_schema: Dict[str, Any]) -> None: + if "$defs" in root_schema: + for model_name in root_schema["$defs"]: + if model_name not in self._model_cache: + self._model_cache[model_name] = None + + for model_name, model_schema in root_schema["$defs"].items(): + if self._model_cache[model_name] is None: + self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema) + + def json_schema_to_pydantic( + self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None + ) -> Type[BaseModel]: + if root_schema is None: + root_schema = schema + self._process_definitions(root_schema) + + if "$ref" in schema: + resolved = self._resolve_ref(schema["$ref"], root_schema) + schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}} + + if "allOf" in schema: + merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for s in schema["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in schema.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + schema = merged + + return self._json_schema_to_model(schema, model_name, root_schema) + + def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]: + types: List[Any] = [] + for s in schemas: + if "$ref" in s: + types.append(self.get_ref(s["$ref"].split("/")[-1])) + elif "enum" in s: + types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any) + else: + json_type = s.get("type") + if json_type not in TYPE_MAPPING: + raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union") + + # Handle array types with items specification + if json_type == "array" and "items" in s: + item_schema = s["items"] + if "$ref" in item_schema: + item_type = self.get_ref(item_schema["$ref"].split("/")[-1]) + else: + item_type_name = item_schema.get("type") + if item_type_name is None: + item_type = str + elif item_type_name not in TYPE_MAPPING: + raise UnsupportedKeywordError(f"Unsupported item type `{item_type_name}` in union array") + else: + item_type = TYPE_MAPPING[item_type_name] + + constraints = {} + if "minItems" in s: + constraints["min_length"] = s["minItems"] + if "maxItems" in s: + constraints["max_length"] = s["maxItems"] + + array_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type] + types.append(array_type) + else: + types.append(TYPE_MAPPING[json_type]) + return types + + def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any: + json_type = value.get("type") + if json_type not in TYPE_MAPPING: + raise UnsupportedKeywordError( + f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`" + ) + + base_type = TYPE_MAPPING[json_type] + constraints: Dict[str, Any] = {} + + if json_type == "string": + if "minLength" in value: + constraints["min_length"] = value["minLength"] + if "maxLength" in value: + constraints["max_length"] = value["maxLength"] + if "pattern" in value: + constraints["pattern"] = value["pattern"] + if constraints: + base_type = constr(**constraints) + + elif json_type == "integer": + if "minimum" in value: + constraints["ge"] = value["minimum"] + if "maximum" in value: + constraints["le"] = value["maximum"] + if "exclusiveMinimum" in value: + constraints["gt"] = value["exclusiveMinimum"] + if "exclusiveMaximum" in value: + constraints["lt"] = value["exclusiveMaximum"] + if constraints: + base_type = conint(**constraints) + + elif json_type == "number": + if "minimum" in value: + constraints["ge"] = value["minimum"] + if "maximum" in value: + constraints["le"] = value["maximum"] + if "exclusiveMinimum" in value: + constraints["gt"] = value["exclusiveMinimum"] + if "exclusiveMaximum" in value: + constraints["lt"] = value["exclusiveMaximum"] + if constraints: + base_type = confloat(**constraints) + + elif json_type == "array": + if "minItems" in value: + constraints["min_length"] = value["minItems"] + if "maxItems" in value: + constraints["max_length"] = value["maxItems"] + item_schema = value.get("items", {"type": "string"}) + if "$ref" in item_schema: + item_type = self.get_ref(item_schema["$ref"].split("/")[-1]) + else: + item_type_name = item_schema.get("type") + if item_type_name is None: + item_type = str + elif item_type_name not in TYPE_MAPPING: + raise UnsupportedKeywordError( + f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`" + ) + else: + item_type = TYPE_MAPPING[item_type_name] + + base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type] + + if "format" in value: + format_type = FORMAT_MAPPING.get(value["format"]) + if format_type is None: + raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`") + if not isinstance(format_type, type): + return format_type + if not issubclass(format_type, str): + return format_type + return format_type + + return base_type + + def _json_schema_to_model( + self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any] + ) -> Type[BaseModel]: + if "allOf" in schema: + merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for s in schema["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in schema.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + schema = merged + + fields: Dict[str, tuple[Any, FieldInfo]] = {} + required_fields = set(schema.get("required", [])) + + for key, value in schema.get("properties", {}).items(): + if "$ref" in value: + ref_name = value["$ref"].split("/")[-1] + field_type = self.get_ref(ref_name) + elif "anyOf" in value: + sub_models = self._resolve_union_types(value["anyOf"]) + field_type = Union[tuple(sub_models)] + elif "oneOf" in value: + sub_models = self._resolve_union_types(value["oneOf"]) + field_type = Union[tuple(sub_models)] + if "discriminator" in value: + discriminator = value["discriminator"]["propertyName"] + field_type = Annotated[field_type, Field(discriminator=discriminator)] + elif "enum" in value: + field_type = Literal[tuple(value["enum"])] + elif "allOf" in value: + merged = {"type": "object", "properties": {}, "required": []} + for s in value["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in value.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema) + elif value.get("type") == "object" and "properties" in value: + field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema) + else: + field_type = self._extract_field_type(key, value, model_name, root_schema) + + if field_type is None: + raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`") + + default_value = value.get("default") + is_required = key in required_fields + + if not is_required and default_value is None: + field_type = Optional[field_type] + + field_args = { + "default": default_value if not is_required else ..., + } + if "title" in value: + field_args["title"] = value["title"] + if "description" in value: + field_args["description"] = value["description"] + + fields[key] = ( + field_type, + _make_field( + default_value if not is_required else ..., + title=value.get("title"), + description=value.get("description"), + ), + ) + + model: Type[BaseModel] = create_model(model_name, **cast(dict[str, Any], fields)) + model.model_rebuild() + return model + + +def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]: + """ + Convert a JSON Schema dictionary to a fully-typed Pydantic model. + + This function handles schema translation and validation logic to produce + a Pydantic model. + + **Supported JSON Schema Features** + + - **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null` + - **String formats**: + - `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5` + - `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network` + - `date`, `time`, `date-time`, `duration` + - `byte`, `binary`, `password`, `path` + - **String constraints**: + - `minLength`, `maxLength`, `pattern` + - **Numeric constraints**: + - `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum` + - **Array constraints**: + - `minItems`, `maxItems`, `items` + - **Object schema support**: + - `properties`, `required`, `title`, `description`, `default` + - **Enums**: + - Converted to Python `Literal` type + - **Union types**: + - `anyOf`, `oneOf` supported with optional `discriminator` + - **Inheritance and composition**: + - `allOf` merges multiple schemas into one model + - **$ref and $defs resolution**: + - Supports references to sibling definitions and self-referencing schemas + + .. code-block:: python + + from agentdhal_core.utils import schema_to_pydantic_model + + # Example 1: Simple user model + schema = { + "title": "User", + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string", "format": "email"}, + "age": {"type": "integer", "minimum": 0}, + }, + "required": ["name", "email"], + } + + UserModel = schema_to_pydantic_model(schema) + user = UserModel(name="Alice", email="alice@example.com", age=30) + + .. code-block:: python + + from agentdhal_core.utils import schema_to_pydantic_model + + # Example 2: Nested model + schema = { + "title": "BlogPost", + "type": "object", + "properties": { + "title": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "author": { + "type": "object", + "properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}}, + "required": ["name"], + }, + }, + "required": ["title", "author"], + } + + BlogPost = schema_to_pydantic_model(schema) + + + .. code-block:: python + + from agentdhal_core.utils import schema_to_pydantic_model + + # Example 3: allOf merging with $refs + schema = { + "title": "EmployeeWithDepartment", + "allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}], + "$defs": { + "Employee": { + "type": "object", + "properties": {"id": {"type": "string"}, "name": {"type": "string"}}, + "required": ["id", "name"], + }, + "Department": { + "type": "object", + "properties": {"department": {"type": "string"}}, + "required": ["department"], + }, + }, + } + + Model = schema_to_pydantic_model(schema) + + .. code-block:: python + + from agentdhal_core.utils import schema_to_pydantic_model + + # Example 4: Self-referencing (recursive) model + schema = { + "title": "Category", + "type": "object", + "properties": { + "name": {"type": "string"}, + "subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}}, + }, + "required": ["name"], + "$defs": { + "Category": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}}, + }, + "required": ["name"], + } + }, + } + + Category = schema_to_pydantic_model(schema) + + .. code-block:: python + + # Example 5: Serializing and deserializing with Pydantic + + from uuid import uuid4 + from pydantic import BaseModel, EmailStr, Field + from typing import Optional, List, Dict, Any + from agentdhal_core.utils import schema_to_pydantic_model + + + class Address(BaseModel): + street: str + city: str + zipcode: str + + + class User(BaseModel): + id: str + name: str + email: EmailStr + age: int = Field(..., ge=18) + address: Address + + + class Employee(BaseModel): + id: str + name: str + manager: Optional["Employee"] = None + + + class Department(BaseModel): + name: str + employees: List[Employee] + + + class ComplexModel(BaseModel): + user: User + extra_info: Optional[Dict[str, Any]] = None + sub_items: List[Employee] + + + # Convert ComplexModel to JSON schema + complex_schema = ComplexModel.model_json_schema() + + # Rebuild a new Pydantic model from JSON schema + ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel") + + # Instantiate reconstructed model + reconstructed = ReconstructedModel( + user={ + "id": str(uuid4()), + "name": "Alice", + "email": "alice@example.com", + "age": 30, + "address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"}, + }, + sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}], + ) + + print(reconstructed.model_dump()) + + + Args: + schema (Dict[str, Any]): A valid JSON Schema dictionary. + model_name (str, optional): The name of the root model. Defaults to "GeneratedModel". + + Returns: + Type[BaseModel]: A dynamically generated Pydantic model class. + + Raises: + ReferenceNotFoundError: If a `$ref` key references a missing entry. + FormatNotSupportedError: If a `format` keyword is unknown or unsupported. + UnsupportedKeywordError: If the schema contains an unsupported `type`. + + See Also: + - :class:`pydantic.BaseModel` + - :func:`pydantic.create_model` + - https://json-schema.org/ + """ + ... + + return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name) diff --git a/agent_dhal/agentdhal_core/utils/_load_json.py b/agent_dhal/agentdhal_core/utils/_load_json.py new file mode 100644 index 0000000..95ccb0e --- /dev/null +++ b/agent_dhal/agentdhal_core/utils/_load_json.py @@ -0,0 +1,20 @@ +import json +import re +from typing import Any, Dict, List + + +def extract_json_from_str(content: str) -> List[Dict[str, Any]]: + """Extract JSON objects from a string. Supports backtick enclosed JSON objects""" + pattern = re.compile(r"```(?:\s*([\w\+\-]+))?\n([\s\S]*?)```") + matches = pattern.findall(content) + ret: List[Dict[str, Any]] = [] + # If no matches found, assume the entire content is a JSON object + if not matches: + ret.append(json.loads(content)) + for match in matches: + language = match[0].strip() if match[0] else None + if language and language.lower() != "json": + raise ValueError(f"Expected JSON object, but found language: {language}") + content = match[1] + ret.append(json.loads(content)) + return ret diff --git a/agent_dhal/agentdhal_extensions/__init__.py b/agent_dhal/agentdhal_extensions/__init__.py new file mode 100644 index 0000000..aadc80e --- /dev/null +++ b/agent_dhal/agentdhal_extensions/__init__.py @@ -0,0 +1,3 @@ +import importlib.metadata + +__version__ = importlib.metadata.version("agentdhal_extensions") diff --git a/agent_dhal/agentdhal_extensions/agents/azure/__init__.py b/agent_dhal/agentdhal_extensions/agents/azure/__init__.py new file mode 100644 index 0000000..ad60797 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/azure/__init__.py @@ -0,0 +1,10 @@ +try: + from ._azure_ai_agent import AzureAIAgent +except ImportError as e: + raise ImportError( + "Dependencies for AzureAIAgent not found. " + 'Please install autogen-ext with the "azure" extra: ' + 'pip install "agentdhal-ext[azure]"' + ) from e + +__all__ = ["AzureAIAgent"] diff --git a/agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py b/agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py new file mode 100644 index 0000000..ddc42e3 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py @@ -0,0 +1,1096 @@ +import asyncio +import json +import logging +import os +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Set, + cast, +) + +from agentdhal_agentchat import TRACE_LOGGER_NAME +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import ( + AgentEvent, + BaseChatMessage, + ChatMessage, + HandoffMessage, + MultiModalMessage, + StopMessage, + TextMessage, + ToolCallExecutionEvent, + ToolCallRequestEvent, +) +from agentdhal_core import CancellationToken, FunctionCall +from agentdhal_core.models._types import FunctionExecutionResult +from agentdhal_core.tools import FunctionTool, Tool + +from azure.ai.agents.models import ( + Agent, + AgentsResponseFormat, + AgentThread, + AzureAISearchToolDefinition, + AzureFunctionToolDefinition, + BingGroundingToolDefinition, + CodeInterpreterToolDefinition, + CodeInterpreterToolResource, + FileInfo, + FilePurpose, + FileSearchToolDefinition, + FileSearchToolResource, + FileState, + FunctionDefinition, + FunctionToolDefinition, + ListSortOrder, + MessageRole, + MessageTextUrlCitationAnnotation, + RunStatus, + ThreadRun, + ToolDefinition, + ToolOutput, + ToolResources, + VectorStore, + VectorStoreChunkingStrategyRequest, + VectorStoreDataSource, + VectorStoreExpirationPolicy, +) +from azure.ai.agents.models._patch import ThreadMessage +from azure.ai.projects.aio import AIProjectClient + +from ._types import AzureAIAgentState, ListToolType + +trace_logger = logging.getLogger(TRACE_LOGGER_NAME) + + +class AzureAIAgent(BaseChatAgent): + """ + Azure AI Assistant agent for AutoGen. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[azure]" # For Azure AI Foundry Agent Service + + This agent leverages the Azure AI Assistant API to create AI assistants with capabilities like: + + * Code interpretation and execution + * Grounding with Bing search + * File handling and search + * Custom function calling + * Multi-turn conversations + + The agent integrates with AutoGen's messaging system, providing a seamless way to use Azure AI + capabilities within the AutoGen framework. It supports tools like code interpreter, + file search, and various grounding mechanisms. + + Agent name must be a valid Python identifier: + 1. It must start with a letter (A-Z, a-z) or an underscore (_). + 2. It can only contain letters, digits (0-9), or underscores. + 3. It cannot be a Python keyword. + 4. It cannot contain spaces or special characters. + 5. It cannot start with a digit. + + + Check here on how to create a new secured agent with user-managed identity: + https://learn.microsoft.com/en-us/azure/ai-services/agents/how-to/virtual-networks + + Examples: + + Use the AzureAIAgent to create an agent grounded with Bing: + + .. code-block:: python + + import asyncio + import os + + from agentdhal_agentchat.messages import TextMessage + from agentdhal_core import CancellationToken + from agentdhal_extensions.agents.azure._azure_ai_agent import AzureAIAgent + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential + from azure.ai.agents.models import BingGroundingTool + import dotenv + + + async def bing_example(): + async with DefaultAzureCredential() as credential: + async with AIProjectClient( # type: ignore + credential=credential, endpoint=os.getenv("AZURE_PROJECT_ENDPOINT", "") + ) as project_client: + conn = await project_client.connections.get(name=os.getenv("BING_CONNECTION_NAME", "")) + + bing_tool = BingGroundingTool(conn.id) + agent_with_bing_grounding = AzureAIAgent( + name="bing_agent", + description="An AI assistant with Bing grounding", + project_client=project_client, + deployment_name="gpt-4o", + instructions="You are a helpful assistant.", + tools=bing_tool.definitions, + metadata={"source": "AzureAIAgent"}, + ) + + # For the bing grounding tool to return the citations, the message must contain an instruction for the model to do return them. + # For example: "Please provide citations for the answers" + + result = await agent_with_bing_grounding.on_messages( + messages=[ + TextMessage( + content="What is Microsoft\\'s annual leave policy? Provide citations for your answers.", + source="user", + ) + ], + cancellation_token=CancellationToken(), + message_limit=5, + ) + print(result) + + + if __name__ == "__main__": + dotenv.load_dotenv() + asyncio.run(bing_example()) + + Use the AzureAIAgent to create an agent with file search capability: + + .. code-block:: python + + import asyncio + import os + import tempfile + import urllib.request + + import dotenv + from agentdhal_agentchat.messages import TextMessage + from agentdhal_core import CancellationToken + from agentdhal_extensions.agents.azure._azure_ai_agent import AzureAIAgent + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential + + + async def file_search_example(): + # Download README.md from GitHub + readme_url = "https://raw.githubusercontent.com/microsoft/autogen/refs/heads/main/README.md" + temp_file = None + + try: + # Create a temporary file to store the downloaded README + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".md") + urllib.request.urlretrieve(readme_url, temp_file.name) + print(f"Downloaded README.md to {temp_file.name}") + + async with DefaultAzureCredential() as credential: + async with AIProjectClient( # type: ignore + credential=credential, endpoint=os.getenv("AZURE_PROJECT_ENDPOINT", "") + ) as project_client: + agent_with_file_search = AzureAIAgent( + name="file_search_agent", + description="An AI assistant with file search capabilities", + project_client=project_client, + deployment_name="gpt-4.1-mini", + instructions="You are a helpful assistant.", + tools=["file_search"], + metadata={"source": "AzureAIAgent"}, + ) + + ct: CancellationToken = CancellationToken() + # Use the downloaded README file for file search + await agent_with_file_search.on_upload_for_file_search( + file_paths=[temp_file.name], + vector_store_name="file_upload_index", + vector_store_metadata={"source": "AzureAIAgent"}, + cancellation_token=ct, + vector_store_polling_interval=60, + ) + result = await agent_with_file_search.on_messages( + messages=[ + TextMessage( + content="Hello, what is AutoGen and what capabilities does it have?", source="user" + ) + ], + cancellation_token=ct, + message_limit=5, + ) + print(result) + finally: + # Clean up the temporary file + if temp_file and os.path.exists(temp_file.name): + os.unlink(temp_file.name) + print(f"Removed temporary file {temp_file.name}") + + + if __name__ == "__main__": + dotenv.load_dotenv() + asyncio.run(file_search_example()) + + Use the AzureAIAgent to create an agent with code interpreter capability: + + .. code-block:: python + + import asyncio + import os + + import dotenv + from agentdhal_agentchat.messages import TextMessage + from agentdhal_core import CancellationToken + from agentdhal_extensions.agents.azure._azure_ai_agent import AzureAIAgent + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential + + + async def code_interpreter_example(): + async with DefaultAzureCredential() as credential: + async with AIProjectClient( # type: ignore + credential=credential, endpoint=os.getenv("AZURE_PROJECT_ENDPOINT", "") + ) as project_client: + agent_with_code_interpreter = AzureAIAgent( + name="code_interpreter_agent", + description="An AI assistant with code interpreter capabilities", + project_client=project_client, + deployment_name="gpt-4.1-mini", + instructions="You are a helpful assistant.", + tools=["code_interpreter"], + metadata={"source": "AzureAIAgent"}, + ) + + await agent_with_code_interpreter.on_upload_for_code_interpreter( + file_paths="/workspaces/autogen/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/data/nifty_500_quarterly_results.csv", + cancellation_token=CancellationToken(), + polling_interval=5, + ) + + result = await agent_with_code_interpreter.on_messages( + messages=[ + TextMessage( + content="Aggregate the number of stocks per industry and give me a markdown table as a result?", + source="user", + ) + ], + cancellation_token=CancellationToken(), + ) + + print(result) + + + if __name__ == "__main__": + dotenv.load_dotenv() + asyncio.run(code_interpreter_example()) + """ + + def __init__( + self, + name: str, + description: str, + project_client: AIProjectClient, + deployment_name: str, + instructions: str, + tools: Optional[ListToolType] = None, + agent_id: Optional[str] = None, + thread_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + response_format: Optional[AgentsResponseFormat] = None, + temperature: Optional[float] = None, + tool_resources: Optional[ToolResources] = None, + top_p: Optional[float] = None, + ) -> None: + """ + Initialize the Azure AI Agent. + + Args: + name (str): The name of the agent. Must be a valid Python identifier. + description (str): A brief description of the agent's purpose. + project_client (AIProjectClient): The Azure AI Project client for API interactions. + deployment_name (str): The model deployment name to use for the agent (e.g., "gpt-4"). + instructions (str): Detailed instructions for the agent's behavior. + tools (Optional[Iterable[Union[str, ToolDefinition, Tool, Callable]]]): A list of tools the agent can use. + Supported string values: "file_search", "code_interpreter", "bing_grounding", + "azure_ai_search", "azure_function", "sharepoint_grounding". + agent_id (Optional[str]): Existing agent ID to use instead of creating a new one. + thread_id (Optional[str]): Existing thread ID to continue a conversation. + metadata (Optional[Dict[str, str]]): Additional metadata for the agent. + response_format (Optional[_types.AgentsApiResponseFormatOption]): Format options for the agent's responses. + temperature (Optional[float]): Sampling temperature, controls randomness of output. + tool_resources (Optional[models.ToolResources]): Resources configuration for agent tools. + top_p (Optional[float]): An alternative to temperature, nucleus sampling parameter. + + Raises: + ValueError: If an unsupported tool type is provided. + """ + super().__init__(name, description) + + if tools is None: + tools = [] + + self._original_tools: list[Tool] = [] + + converted_tools: List[ToolDefinition] = [] + self._add_tools(tools, converted_tools) + + self._project_client = project_client + self._agent: Optional[Agent] = None + self._thread: Optional[AgentThread] = None + self._init_thread_id = thread_id + self._deployment_name = deployment_name + self._instructions = instructions + self._api_tools = converted_tools + self._agent_id = agent_id + self._metadata = metadata + self._response_format = response_format + self._temperature = temperature + self._tool_resources = tool_resources + self._top_p = top_p + self._vector_store_id: Optional[str] = None + self._uploaded_file_ids: List[str] = [] + + self._initial_message_ids: Set[str] = set() + self._initial_state_retrieved: bool = False + + # Properties + @property + def produced_message_types(self) -> Sequence[type[ChatMessage]]: + """The types of messages that the assistant agent produces.""" + return (TextMessage,) + + @property + def thread_id(self) -> str: + if self._thread is None: + raise ValueError("Thread not initialized") + return self._thread.id + + @property + def _get_agent_id(self) -> str: + if self._agent is None: + raise ValueError("Agent not initialized") + return self._agent.id + + @property + def description(self) -> str: + if not self._description: + raise ValueError("Description not initialized") + return self._description + + @property + def agent_id(self) -> str: + if not self._agent_id: + raise ValueError("Agent not initialized") + return self._agent_id + + @property + def deployment_name(self) -> str: + if not self._deployment_name: + raise ValueError("Deployment name not initialized") + return self._deployment_name + + @property + def instructions(self) -> str: + if not self._instructions: + raise ValueError("Instructions not initialized") + return self._instructions + + @property + def tools(self) -> List[ToolDefinition]: + """ + Get the list of tools available to the agent. + + Returns: + List[ToolDefinition]: The list of tool definitions. + """ + return self._api_tools + + def _add_tools(self, tools: Optional[ListToolType], converted_tools: List[ToolDefinition]) -> None: + """ + Convert various tool formats to Azure AI Agent tool definitions. + + Args: + tools: List of tools in various formats (string identifiers, ToolDefinition objects, Tool objects, or callables) + converted_tools: List to which converted tool definitions will be added + + Raises: + ValueError: If an unsupported tool type is provided + """ + if tools is None: + return + + for tool in tools: + if isinstance(tool, str): + if tool == "file_search": + converted_tools.append(FileSearchToolDefinition()) + elif tool == "code_interpreter": + converted_tools.append(CodeInterpreterToolDefinition()) + elif tool == "bing_grounding": + converted_tools.append(BingGroundingToolDefinition()) # type: ignore + elif tool == "azure_ai_search": + converted_tools.append(AzureAISearchToolDefinition()) + elif tool == "azure_function": + converted_tools.append(AzureFunctionToolDefinition()) # type: ignore + # elif tool == "sharepoint_grounding": + # converted_tools.append(SharepointToolDefinition()) # type: ignore + else: + raise ValueError(f"Unsupported tool string: {tool}") + elif isinstance(tool, ToolDefinition): + converted_tools.append(tool) + elif isinstance(tool, Tool): + self._original_tools.append(tool) + converted_tools.append(self._convert_tool_to_function_tool_definition(tool)) + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + function_tool = FunctionTool(tool, description=description) + self._original_tools.append(function_tool) + converted_tools.append(self._convert_tool_to_function_tool_definition(function_tool)) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + + def _convert_tool_to_function_tool_definition(self, tool: Tool) -> FunctionToolDefinition: + """ + Convert an autogen Tool to an Azure AI Agent function tool definition. + + Args: + tool (Tool): The AutoGen tool to convert + + Returns: + models.FunctionToolDefinition: A function tool definition compatible with Azure AI Agent API + """ + + schema = tool.schema + parameters: Dict[str, object] = {} + + if "parameters" in schema: + parameters = { + "type": schema["parameters"]["type"], + "properties": schema["parameters"]["properties"], + } + if "required" in schema["parameters"]: + parameters["required"] = schema["parameters"]["required"] + + func_definition = FunctionDefinition(name=tool.name, description=tool.description, parameters=parameters) + + return FunctionToolDefinition( + function=func_definition, + ) + + async def _ensure_initialized(self, create_new_thread: bool = False, create_new_agent: bool = False) -> None: + """ + Ensure agent and thread are properly initialized before operations. + + This method ensures that both the Azure AI Agent and thread are created or retrieved + from existing IDs. It also handles retrieving the initial state of an existing thread + when needed. + + Args: + create_new_thread (bool): When True, creates a new thread even if thread_id is provided + create_new_agent (bool): When True, creates a new agent even if agent_id is provided + + Raises: + ValueError: If agent or thread creation fails + """ + if self._agent is None or create_new_agent: + if self._agent_id and create_new_agent is False: + self._agent = await self._project_client.agents.get_agent(agent_id=self._agent_id) + else: + self._agent = await self._project_client.agents.create_agent( + name=self.name, + model=self._deployment_name, + description=self.description, + instructions=self._instructions, + tools=self._api_tools, + metadata=self._metadata, + response_format=self._response_format if self._response_format else None, # type: ignore + temperature=self._temperature, + tool_resources=self._tool_resources if self._tool_resources else None, # type: ignore + top_p=self._top_p, + ) + + if self._thread is None or create_new_thread: + if self._init_thread_id and create_new_thread is False: + self._thread = await self._project_client.agents.threads.get(thread_id=self._init_thread_id) + # Retrieve initial state only once + if not self._initial_state_retrieved: + await self._retrieve_initial_state() + self._initial_state_retrieved = True + else: + self._thread = await self._project_client.agents.threads.create() + + async def _retrieve_initial_state(self) -> None: + """ + Retrieve and store the initial state of messages in the thread. + + This method retrieves all message IDs from an existing thread to track which + messages were present before this agent instance started interacting with the thread. + It handles pagination to ensure all messages are captured. + """ + # Retrieve all initial message IDs + initial_message_ids: Set[str] = set() + async for msg in self._project_client.agents.messages.list( + thread_id=self.thread_id, + order=ListSortOrder.ASCENDING, + limit=100, + ): + initial_message_ids.add(msg.id) + self._initial_message_ids = initial_message_ids + + async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str: + """ + Execute a tool call requested by the Azure AI agent. + + Args: + tool_call (FunctionCall): The function call information including name and arguments + cancellation_token (CancellationToken): Token for cancellation handling + + Returns: + str: The string representation of the tool call result + + Raises: + ValueError: If the requested tool is not available or no tools are registered + """ + if not self._original_tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._original_tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token, call_id=tool_call.id) + return tool.return_value_as_string(result) + + async def _upload_files( + self, + file_paths: str | Iterable[str], + purpose: str = "assistant", + polling_interval: float = 0.5, + cancellation_token: Optional[CancellationToken] = None, + ) -> List[str]: + """ + Upload files to the Azure AI Assistant API. + + This method handles uploading one or more files to be used by the agent + and tracks their IDs in the agent's state. + + Args: + file_paths (str | Iterable[str]): Path(s) to file(s) to upload + purpose (str): The purpose of the file, defaults to "assistant" + polling_interval (float): Time to sleep between polling for file status + cancellation_token (Optional[CancellationToken]): Token for cancellation handling + + Returns: + List[str]: List of file IDs for the uploaded files + + Raises: + ValueError: If file upload fails + """ + if cancellation_token is None: + cancellation_token = CancellationToken() + + await self._ensure_initialized() + + if isinstance(file_paths, str): + file_paths = [file_paths] + + file_ids: List[str] = [] + for file_path in file_paths: + file_name = os.path.basename(file_path) + + file: FileInfo = await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.files.upload_and_poll( + file_path=file_path, purpose=purpose, polling_interval=polling_interval + ) + ) + ) + + if file.status != FileState.PROCESSED: + raise ValueError(f"File upload failed with status {file.status}") + + trace_logger.debug(f"File uploaded successfully: {file.id}, {file_name}") + + file_ids.append(file.id) + self._uploaded_file_ids.append(file.id) + + return file_ids + + # Public Methods + async def on_messages( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: Optional[CancellationToken] = None, + message_limit: int = 1, + ) -> Response: + """ + Process incoming messages and return a response from the Azure AI agent. + + This method is the primary entry point for interaction with the agent. + It delegates to on_messages_stream and returns the final response. + + Args: + messages (Sequence[BaseChatMessage]): The messages to process + cancellation_token (CancellationToken): Token for cancellation handling + message_limit (int, optional): Maximum number of messages to retrieve from the thread + + Returns: + Response: The agent's response, including the chat message and any inner events + + Raises: + AssertionError: If the stream doesn't return a final result + """ + async for message in self.on_messages_stream( + messages=messages, cancellation_token=cancellation_token, message_limit=message_limit + ): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, + messages: Sequence[BaseChatMessage], + cancellation_token: Optional[CancellationToken] = None, + message_limit: int = 1, + polling_interval: float = 0.5, + ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]: + """ + Process incoming messages and yield streaming responses from the Azure AI agent. + + This method handles the complete interaction flow with the Azure AI agent: + 1. Processing input messages + 2. Creating and monitoring a run + 3. Handling tool calls and their results + 4. Retrieving and returning the agent's final response + + The method yields events during processing (like tool calls) and finally yields + the complete Response with the agent's message. + + Args: + messages (Sequence[BaseChatMessage]): The messages to process + cancellation_token (CancellationToken): Token for cancellation handling + message_limit (int, optional): Maximum number of messages to retrieve from the thread + polling_interval (float, optional): Time to sleep between polling for run status + + Yields: + AgentEvent | ChatMessage | Response: Events during processing and the final response + + Raises: + ValueError: If the run fails or no message is received from the assistant + """ + if cancellation_token is None: + cancellation_token = CancellationToken() + + await self._ensure_initialized() + + # Process all messages in sequence + for message in messages: + if isinstance(message, (TextMessage, MultiModalMessage)): + await self.handle_text_message(str(message.content), cancellation_token) + elif isinstance(message, (StopMessage, HandoffMessage)): + await self.handle_text_message(message.content, cancellation_token) + + # Inner messages for tool calls + inner_messages: List[AgentEvent | ChatMessage] = [] + + # Create and start a run + run: ThreadRun = await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.runs.create( + thread_id=self.thread_id, + agent_id=self._get_agent_id, + ) + ) + ) + + # Wait for run completion by polling + while True: + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.runs.get( + thread_id=self.thread_id, + run_id=run.id, + ) + ) + ) + + if run.status == RunStatus.FAILED: + raise ValueError(f"Run failed: {run.last_error}") + + # If the run requires action (function calls), execute tools and continue + if run.status == RunStatus.REQUIRES_ACTION and run.required_action is not None: + tool_calls: List[FunctionCall] = [] + submit_tool_outputs = getattr(run.required_action, "submit_tool_outputs", None) + if submit_tool_outputs and hasattr(submit_tool_outputs, "tool_calls"): + for required_tool_call in submit_tool_outputs.tool_calls: + if required_tool_call.type == "function": + tool_calls.append( + FunctionCall( + id=required_tool_call.id, + name=required_tool_call.function.name, + arguments=required_tool_call.function.arguments, + ) + ) + + # Add tool call message to inner messages + tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls) + inner_messages.append(tool_call_msg) + trace_logger.debug(tool_call_msg) + yield tool_call_msg + + # Execute tool calls and get results + tool_outputs: List[FunctionExecutionResult] = [] + + # TODO: Support parallel execution of tool calls + + for tool_call in tool_calls: + try: + result = await self._execute_tool_call(tool_call, cancellation_token) + is_error = False + except Exception as e: + result = f"Error: {e}" + is_error = True + tool_outputs.append( + FunctionExecutionResult( + content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name + ) + ) + + # Add tool result message to inner messages + tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs) + inner_messages.append(tool_result_msg) + trace_logger.debug(tool_result_msg) + yield tool_result_msg + + # Submit tool outputs back to the run + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.runs.submit_tool_outputs( + thread_id=self.thread_id, + run_id=run.id, + tool_outputs=[ToolOutput(tool_call_id=t.call_id, output=t.content) for t in tool_outputs], + ) + ) + ) + continue + + if run.status == RunStatus.COMPLETED: + break + + # TODO support for parameter to control polling interval + await asyncio.sleep(polling_interval) + + # After run is completed, get the messages + trace_logger.debug("Retrieving messages from thread") + # Collect up to message_limit messages in DESCENDING order, support cancellation + agent_messages: List[ThreadMessage] = [] + async for msg in self._project_client.agents.messages.list( + thread_id=self.thread_id, + order=ListSortOrder.DESCENDING, + limit=message_limit, + ): + if cancellation_token.is_cancelled(): + trace_logger.debug("Message retrieval cancelled by token.") + break + agent_messages.append(msg) + if len(agent_messages) >= message_limit: + break + if not agent_messages: + raise ValueError("No messages received from assistant") + + # Get the last message from the agent (role=AGENT) + last_message: Optional[ThreadMessage] = next( + (m for m in agent_messages if getattr(m, "role", None) == "agent"), None + ) + if not last_message: + trace_logger.debug("No message with AGENT role found, falling back to first message") + last_message = agent_messages[0] # Fallback to first message + if not getattr(last_message, "content", None): + raise ValueError("No content in the last message") + + # Extract text content + message_text = "" + for text_message in last_message.text_messages: + message_text += text_message.text.value + + # Extract citations + citations: list[Any] = [] + + # Try accessing annotations directly + + annotations = getattr(last_message, "annotations", []) + + if isinstance(annotations, list) and annotations: + annotations = cast(List[MessageTextUrlCitationAnnotation], annotations) + + trace_logger.debug(f"Found {len(annotations)} annotations") + for annotation in annotations: + if hasattr(annotation, "url_citation"): # type: ignore + trace_logger.debug(f"Citation found: {annotation.url_citation.url}") + citations.append( + {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None} # type: ignore + ) + # For backwards compatibility + elif hasattr(last_message, "url_citation_annotations") and last_message.url_citation_annotations: + url_annotations = cast(List[Any], last_message.url_citation_annotations) + + trace_logger.debug(f"Found {len(url_annotations)} URL citations") + + for annotation in url_annotations: + citations.append( + {"url": annotation.url_citation.url, "title": annotation.url_citation.title, "text": None} # type: ignore + ) + + elif hasattr(last_message, "file_citation_annotations") and last_message.file_citation_annotations: + file_annotations = cast(List[Any], last_message.file_citation_annotations) + + trace_logger.debug(f"Found {len(file_annotations)} URL citations") + + for annotation in file_annotations: + citations.append( + {"file_id": annotation.file_citation.file_id, "title": None, "text": annotation.file_citation.quote} # type: ignore + ) + + trace_logger.debug(f"Total citations extracted: {len(citations)}") + + # Create the response message with citations as JSON string + chat_message = TextMessage( + source=self.name, content=message_text, metadata={"citations": json.dumps(citations)} if citations else {} + ) + + # Return the assistant's response as a Response with inner messages + yield Response(chat_message=chat_message, inner_messages=inner_messages) + + async def handle_text_message(self, content: str, cancellation_token: Optional[CancellationToken] = None) -> None: + """ + Handle a text message by adding it to the conversation thread. + + Args: + content (str): The text content of the message + cancellation_token (CancellationToken): Token for cancellation handling + + Returns: + None + """ + + if cancellation_token is None: + cancellation_token = CancellationToken() + + await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.messages.create( + thread_id=self.thread_id, + role=MessageRole.USER, + content=content, + ) + ) + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """ + Reset the agent's conversation by creating a new thread. + + This method allows for resetting a conversation without losing the agent + definition or capabilities. It creates a new thread for fresh conversations. + + Note: Currently the Azure AI Agent API has no support for deleting messages, + so a new thread is created instead. + + Args: + cancellation_token (CancellationToken): Token for cancellation handling + """ + # This will enforce the creation of a new thread + await self._ensure_initialized(create_new_thread=True) + + async def save_state(self) -> Mapping[str, Any]: + """ + Save the current state of the agent for future restoration. + + This method serializes the agent's state including IDs for the agent, thread, + messages, and associated resources like vector stores and uploaded files. + + Returns: + Mapping[str, Any]: A dictionary containing the serialized state data + """ + state = AzureAIAgentState( + agent_id=self._agent.id if self._agent else self._agent_id, + thread_id=self._thread.id if self._thread else self._init_thread_id, + initial_message_ids=list(self._initial_message_ids), + vector_store_id=self._vector_store_id, + uploaded_file_ids=self._uploaded_file_ids, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + """ + Load a previously saved state into this agent. + + This method deserializes and restores a previously saved agent state, + setting up the agent to continue a previous conversation or session. + + Args: + state (Mapping[str, Any]): The previously saved state dictionary + """ + agent_state = AzureAIAgentState.model_validate(state) + self._agent_id = agent_state.agent_id + self._init_thread_id = agent_state.thread_id + self._initial_message_ids = set(agent_state.initial_message_ids) + self._vector_store_id = agent_state.vector_store_id + self._uploaded_file_ids = agent_state.uploaded_file_ids + + async def on_upload_for_code_interpreter( + self, + file_paths: str | Iterable[str], + cancellation_token: Optional[CancellationToken] = None, + polling_interval: float = 0.5, + ) -> None: + """ + Upload files to be used with the code interpreter tool. + + This method uploads files for the agent's code interpreter tool and + updates the thread's tool resources to include these files. + + Args: + file_paths (str | Iterable[str]): Path(s) to file(s) to upload + cancellation_token (Optional[CancellationToken]): Token for cancellation handling + polling_interval (float): Time to sleep between polling for file status + + Raises: + ValueError: If file upload fails or the agent doesn't have code interpreter capability + """ + if cancellation_token is None: + cancellation_token = CancellationToken() + + await self._ensure_initialized() + + file_ids = await self._upload_files( + file_paths=file_paths, + cancellation_token=cancellation_token, + polling_interval=polling_interval, + purpose=FilePurpose.AGENTS, + ) + + # Update thread with the new files + thread: AgentThread = await cancellation_token.link_future( + asyncio.ensure_future(self._project_client.agents.threads.get(thread_id=self.thread_id)) + ) + + tool_resources: ToolResources = thread.tool_resources or ToolResources() + code_interpreter_resource = tool_resources.code_interpreter or CodeInterpreterToolResource() + existing_file_ids: List[str] = code_interpreter_resource.file_ids or [] + existing_file_ids.extend(file_ids) + + await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.threads.update( + thread_id=self.thread_id, + tool_resources=ToolResources( + code_interpreter=CodeInterpreterToolResource(file_ids=existing_file_ids) + ), + ) + ) + ) + + async def on_upload_for_file_search( + self, + file_paths: str | Iterable[str], + cancellation_token: CancellationToken, + vector_store_name: Optional[str] = None, + data_sources: Optional[List[VectorStoreDataSource]] = None, + expires_after: Optional[VectorStoreExpirationPolicy] = None, + chunking_strategy: Optional[VectorStoreChunkingStrategyRequest] = None, + vector_store_metadata: Optional[Dict[str, str]] = None, + vector_store_polling_interval: float = 1, + ) -> None: + """ + Upload files to be used with the file search tool. + + This method handles uploading files for the file search capability, creating a vector + store if necessary, and updating the agent's configuration to use the vector store. + + Args: + file_paths (str | Iterable[str]): Path(s) to file(s) to upload + cancellation_token (CancellationToken): Token for cancellation handling + vector_store_name (Optional[str]): Name to assign to the vector store if creating a new one + data_sources (Optional[List[VectorStoreDataSource]]): Additional data sources for the vector store + expires_after (Optional[VectorStoreExpirationPolicy]): Expiration policy for vector store content + chunking_strategy (Optional[VectorStoreChunkingStrategyRequest]): Strategy for chunking file content + vector_store_metadata (Optional[Dict[str, str]]): Additional metadata for the vector store + vector_store_polling_interval (float): Time to sleep between polling for vector store status + + Raises: + ValueError: If file search is not enabled for this agent or file upload fails + """ + await self._ensure_initialized() + + # Check if file_search is enabled in tools + if not any(isinstance(tool, FileSearchToolDefinition) for tool in self._api_tools): + raise ValueError( + "File search is not enabled for this assistant. Add a file_search tool when creating the assistant." + ) + + # Create vector store if not already created + if self._vector_store_id is None: + vector_store: VectorStore = await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.vector_stores.create_and_poll( + file_ids=[], + name=vector_store_name, + data_sources=data_sources, + expires_after=expires_after, + chunking_strategy=chunking_strategy, + metadata=vector_store_metadata, + polling_interval=vector_store_polling_interval, + ) + ) + ) + self._vector_store_id = vector_store.id + + # Update assistant with vector store ID + await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.update_agent( + agent_id=self._get_agent_id, + tools=self._api_tools, + tool_resources=ToolResources( + file_search=FileSearchToolResource(vector_store_ids=[self._vector_store_id]) + ), + ) + ) + ) + + file_ids = await self._upload_files( + file_paths=file_paths, cancellation_token=cancellation_token, purpose=FilePurpose.AGENTS + ) + + # Create file batch with the file IDs + await cancellation_token.link_future( + asyncio.ensure_future( + self._project_client.agents.vector_store_file_batches.create_and_poll( + vector_store_id=self._vector_store_id, + file_ids=file_ids, + polling_interval=vector_store_polling_interval, + ) + ) + ) + + async def close(self) -> None: + """ + Close the Azure AI agent and release any resources. + """ + await self._project_client.close() + + +if __name__ == "__main__": + # Example usage of AzureAIAgent + # Replace with your actual endpoint and credentials + """ + TODO: + [X] Support for file upload + [] Support for sharepoint grounding + [] Support for azure function grounding + [X] Support for file search + [X] Support for custom function calling + [X] Add metadata to the thread (agent_id, source ="AUTODGEN_AGENT") + """ diff --git a/agent_dhal/agentdhal_extensions/agents/azure/_types.py b/agent_dhal/agentdhal_extensions/agents/azure/_types.py new file mode 100644 index 0000000..e0230fb --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/azure/_types.py @@ -0,0 +1,61 @@ +from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union + +from agentdhal_core.tools import Tool +from pydantic import BaseModel, Field + +from azure.ai.agents.models import ( + AzureAISearchToolDefinition, + AzureFunctionToolDefinition, + BingGroundingToolDefinition, + CodeInterpreterToolDefinition, + FileSearchToolDefinition, + MessageTextUrlCitationAnnotation, +) + +ListToolType = Iterable[ + Union[ + Literal[ + "file_search", + "code_interpreter", + "bing_grounding", + "azure_ai_search", + "azure_function", + ], + BingGroundingToolDefinition, + CodeInterpreterToolDefinition, + AzureAISearchToolDefinition, + FileSearchToolDefinition, + AzureFunctionToolDefinition, + Tool, + Callable[..., Any], + Callable[..., Awaitable[Any]], + ] +] + + +class AzureAIAgentState(BaseModel): + """ + Represents the state of an AzureAIAgent that can be saved and loaded. + + This state model keeps track of persistent information about an agent session + including agent and thread identifiers, message history, and associated resources. + + Attributes: + type (str): The type identifier for the state object, always "AzureAIAgentState" + agent_id (Optional[str]): The ID of the Azure AI agent + thread_id (Optional[str]): The ID of the conversation thread + initial_message_ids (List[str]): List of message IDs from the initial state + vector_store_id (Optional[str]): The ID of the associated vector store for file search + uploaded_file_ids (List[str]): List of IDs for files uploaded to the agent + """ + + type: str = Field(default="AzureAIAgentState") + agent_id: Optional[str] = None + thread_id: Optional[str] = None + initial_message_ids: List[str] = Field(default_factory=list) + vector_store_id: Optional[str] = None + uploaded_file_ids: List[str] = Field(default_factory=list) + + +def has_annotations(obj: Any) -> TypeGuard[list[MessageTextUrlCitationAnnotation]]: + return obj is not None and isinstance(obj, list) diff --git a/agent_dhal/agentdhal_extensions/agents/file_surfer/__init__.py b/agent_dhal/agentdhal_extensions/agents/file_surfer/__init__.py new file mode 100644 index 0000000..79d5ba2 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/file_surfer/__init__.py @@ -0,0 +1,3 @@ +from ._file_surfer import FileSurfer + +__all__ = ["FileSurfer"] diff --git a/agent_dhal/agentdhal_extensions/agents/file_surfer/_file_surfer.py b/agent_dhal/agentdhal_extensions/agents/file_surfer/_file_surfer.py new file mode 100644 index 0000000..f55f918 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/file_surfer/_file_surfer.py @@ -0,0 +1,208 @@ +import json +import os +import traceback +from typing import List, Sequence, Tuple + +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import ( + BaseChatMessage, + TextMessage, +) +from agentdhal_agentchat.utils import remove_images +from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + LLMMessage, + SystemMessage, + UserMessage, +) +from pydantic import BaseModel +from typing_extensions import Self + +from ._markdown_file_browser import MarkdownFileBrowser + +# from typing_extensions import Annotated +from ._tool_definitions import ( + TOOL_FIND_NEXT, + TOOL_FIND_ON_PAGE_CTRL_F, + TOOL_OPEN_PATH, + TOOL_PAGE_DOWN, + TOOL_PAGE_UP, +) + + +class FileSurferConfig(BaseModel): + """Configuration for FileSurfer agent""" + + name: str + model_client: ComponentModel + description: str | None = None + + +class FileSurfer(BaseChatAgent, Component[FileSurferConfig]): + """An agent, used by MagenticOne, that acts as a local file previewer. FileSurfer can open and read a variety of common file types, and can navigate the local file hierarchy. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[file-surfer]" + + Args: + name (str): The agent's name + model_client (ChatCompletionClient): The model to use (must be tool-use enabled) + description (str): The agent's description used by the team. Defaults to DEFAULT_DESCRIPTION + base_path (str): The base path to use for the file browser. Defaults to the current working directory. + + """ + + component_config_schema = FileSurferConfig + component_provider_override = "agentdhal_extensions.agents.file_surfer.FileSurfer" + + DEFAULT_DESCRIPTION = "An agent that can handle local files." + + DEFAULT_SYSTEM_MESSAGES = [ + SystemMessage( + content=""" + You are a helpful AI Assistant. + When given a user query, use available functions to help the user with their request.""" + ), + ] + + def __init__( + self, + name: str, + model_client: ChatCompletionClient, + description: str = DEFAULT_DESCRIPTION, + base_path: str = os.getcwd(), + ) -> None: + super().__init__(name, description) + self._model_client = model_client + self._chat_history: List[LLMMessage] = [] + self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path) + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return (TextMessage,) + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + for chat_message in messages: + self._chat_history.append(chat_message.to_model_message()) + try: + _, content = await self._generate_reply(cancellation_token=cancellation_token) + self._chat_history.append(AssistantMessage(content=content, source=self.name)) + return Response(chat_message=TextMessage(content=content, source=self.name)) + + except BaseException: + content = f"File surfing error:\n\n{traceback.format_exc()}" + self._chat_history.append(AssistantMessage(content=content, source=self.name)) + return Response(chat_message=TextMessage(content=content, source=self.name)) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + self._chat_history.clear() + + def _get_browser_state(self) -> Tuple[str, str]: + """ + Get the current state of the browser, including the header and content. + """ + header = f"Path: {self._browser.path}\n" + + if self._browser.page_title is not None: + header += f"Title: {self._browser.page_title}\n" + + current_page = self._browser.viewport_current_page + total_pages = len(self._browser.viewport_pages) + header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n" + + return (header, self._browser.viewport) + + async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]: + history = self._chat_history[0:-1] + last_message = self._chat_history[-1] + assert isinstance(last_message, UserMessage) + + task_content = last_message.content # the last message from the sender is the task + + assert self._browser is not None + + context_message = UserMessage( + source="user", + content=f"Your file viewer is currently open to the file or directory '{self._browser.page_title}' with path '{self._browser.path}'.", + ) + + task_message = UserMessage( + source="user", + content=task_content, + ) + + create_result = await self._model_client.create( + messages=self._get_compatible_context(history + [context_message, task_message]), + tools=[ + TOOL_OPEN_PATH, + TOOL_PAGE_DOWN, + TOOL_PAGE_UP, + TOOL_FIND_NEXT, + TOOL_FIND_ON_PAGE_CTRL_F, + ], + cancellation_token=cancellation_token, + ) + + response = create_result.content + + if isinstance(response, str): + # Answer directly. + return False, response + + elif isinstance(response, list) and all(isinstance(item, FunctionCall) for item in response): + function_calls = response + for function_call in function_calls: + tool_name = function_call.name + + try: + arguments = json.loads(function_call.arguments) + except json.JSONDecodeError as e: + error_str = f"File surfer encountered an error decoding JSON arguments: {e}" + return False, error_str + + if tool_name == "open_path": + path = arguments["path"] + self._browser.open_path(path) + elif tool_name == "page_up": + self._browser.page_up() + elif tool_name == "page_down": + self._browser.page_down() + elif tool_name == "find_on_page_ctrl_f": + search_string = arguments["search_string"] + self._browser.find_on_page(search_string) + elif tool_name == "find_next": + self._browser.find_next() + header, content = self._get_browser_state() + final_response = header.strip() + "\n=======================\n" + content + return False, final_response + + final_response = "TERMINATE" + return False, final_response + + def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Ensure that the messages are compatible with the underlying client, by removing images if needed.""" + if self._model_client.model_info["vision"]: + return messages + else: + return remove_images(messages) + + def _to_config(self) -> FileSurferConfig: + return FileSurferConfig( + name=self.name, + model_client=self._model_client.dump_component(), + description=self.description, + ) + + @classmethod + def _from_config(cls, config: FileSurferConfig) -> Self: + return cls( + name=config.name, + model_client=ChatCompletionClient.load_component(config.model_client), + description=config.description or cls.DEFAULT_DESCRIPTION, + ) diff --git a/agent_dhal/agentdhal_extensions/agents/file_surfer/_markdown_file_browser.py b/agent_dhal/agentdhal_extensions/agents/file_surfer/_markdown_file_browser.py new file mode 100644 index 0000000..93d6932 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/file_surfer/_markdown_file_browser.py @@ -0,0 +1,317 @@ +# ruff: noqa: E722 +import datetime +import io +import os +import re +import time +from typing import List, Optional, Tuple, Union + +# TODO: Fix unfollowed import +from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException # type: ignore + + +class MarkdownFileBrowser: + """ + (In preview) An extremely simple Markdown-powered file browser. + """ + + # TODO: Fix unfollowed import + def __init__( # type: ignore + self, + viewport_size: Union[int, None] = 1024 * 8, + base_path: str | None = os.getcwd(), + cwd: str | None = None, + ): + """ + Instantiate a new MarkdownFileBrowser. + + Arguments: + viewport_size: Approximately how many *characters* fit in the viewport. Viewport dimensions are adjusted dynamically to avoid cutting off words (default: 8192). + base_path: The base path to use for the file browser. Files outside this path cannot be accessed. Defaults to the current working directory. + cwd: The browser's current working directory. Defaults to the system's current working directory. + """ + self.viewport_size = viewport_size # Applies only to the standard uri types + self.history: List[Tuple[str, float]] = list() + self.page_title: Optional[str] = None + self.viewport_current_page = 0 + self.viewport_pages: List[Tuple[int, int]] = list() + self._markdown_converter = MarkItDown() + self._base_path = None if base_path is None else os.path.realpath(base_path) + self._page_content: str = "" + self._find_on_page_query: Union[str, None] = None + self._find_on_page_last_result: Union[int, None] = None # Location of the last result + + # Set the working directory + if cwd is None: + if self._validate_path(os.getcwd()): + # Use the current working directory if it's in the base path + cwd = os.path.realpath(os.getcwd()) + elif self._base_path is not None: + # Otherwise, use the base path + cwd = os.path.realpath(self._base_path) + else: + raise ValueError("No valid working directory (cwd) provided.") + elif not self._validate_path(cwd): + # A cwd was provided, but it is not valid + raise ValueError(f"Working directory (cwd) '{cwd}' is not valid. It must be within the base path.") + + # Populate the history with the current working directory + self.set_path(os.path.realpath(cwd)) + + @property + def path(self) -> str: + """Return the path of the current page.""" + assert len(self.history) > 0 + return self.history[-1][0] + + def _validate_path(self, path: str) -> bool: + """Validates the path to ensure it is within the base path. + + Arguments: + path: The path to validate. + Returns: + True if the path is valid, False otherwise. + """ + if self._base_path is None: + return True + + # Normalize the paths + path = os.path.realpath(path) + base = os.path.realpath(self._base_path) + + # Check if the path is within the base path + if os.path.commonpath([path, base]) != base: + return False + + return True + + def set_path(self, path: str) -> None: + """Sets the path of the current page. + This will result in the file being opened for reading. + + Arguments: + path: An absolute or relative path of the file or directory to open." + """ + + # Handle relative paths + path = os.path.expanduser(path) + if not os.path.isabs(path): + if os.path.isfile(self.path): + path = os.path.abspath(os.path.join(os.path.dirname(self.path), path)) + elif os.path.isdir(self.path): + path = os.path.abspath(os.path.join(self.path, path)) + # If neither a file or a directory, take it verbatim + + # Validating the path wrt. the base path is done in _open_path + path = os.path.realpath(path) + + self.history.append((path, time.time())) + self._open_path(path) + self.viewport_current_page = 0 + self.find_on_page_query = None + self.find_on_page_viewport = None + + @property + def viewport(self) -> str: + """Return the content of the current viewport.""" + bounds = self.viewport_pages[self.viewport_current_page] + return self.page_content[bounds[0] : bounds[1]] + + @property + def page_content(self) -> str: + """Return the full contents of the current page.""" + return self._page_content + + def _set_page_content(self, content: str, split_pages: bool = True) -> None: + """Sets the text content of the current page.""" + self._page_content = content + + if split_pages: + self._split_pages() + else: + self.viewport_pages = [(0, len(self._page_content))] + + if self.viewport_current_page >= len(self.viewport_pages): + self.viewport_current_page = len(self.viewport_pages) - 1 + + def page_down(self) -> None: + """Move the viewport down one page, if possible.""" + self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1) + + def page_up(self) -> None: + """Move the viewport up one page, if possible.""" + self.viewport_current_page = max(self.viewport_current_page - 1, 0) + + def find_on_page(self, query: str) -> Union[str, None]: + """Searches for the query from the current viewport forward, looping back to the start if necessary.""" + + # Did we get here via a previous find_on_page search with the same query? + # If so, map to find_next + if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result: + return self.find_next() + + # Ok it's a new search start from the current viewport + self._find_on_page_query = query + viewport_match = self._find_next_viewport(query, self.viewport_current_page) + if viewport_match is None: + self._find_on_page_last_result = None + return None + else: + self.viewport_current_page = viewport_match + self._find_on_page_last_result = viewport_match + return self.viewport + + def find_next(self) -> Union[str, None]: + """Scroll to the next viewport that matches the query""" + + if self._find_on_page_query is None: + return None + + starting_viewport = self._find_on_page_last_result + if starting_viewport is None: + starting_viewport = 0 + else: + starting_viewport += 1 + if starting_viewport >= len(self.viewport_pages): + starting_viewport = 0 + + viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport) + if viewport_match is None: + self._find_on_page_last_result = None + return None + else: + self.viewport_current_page = viewport_match + self._find_on_page_last_result = viewport_match + return self.viewport + + def _find_next_viewport(self, query: Optional[str], starting_viewport: int) -> Union[int, None]: + """Search for matches between the starting viewport looping when reaching the end.""" + + if query is None: + return None + + # Normalize the query, and convert to a regular expression + nquery = re.sub(r"\*", "__STAR__", query) + nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " " + nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word + nquery = nquery.replace("__STAR__", ".*").lower() + + if nquery.strip() == "": + return None + + idxs: List[int] = list() + idxs.extend(range(starting_viewport, len(self.viewport_pages))) + idxs.extend(range(0, starting_viewport)) + + for i in idxs: + bounds = self.viewport_pages[i] + content = self.page_content[bounds[0] : bounds[1]] + + # TODO: Remove markdown links and images + ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " " + if re.search(nquery, ncontent): + return i + + return None + + def open_path(self, path: str) -> str: + """Open a file or directory in the file surfer.""" + self.set_path(path) + return self.viewport + + def _split_pages(self) -> None: + """Split the page contents into pages that are approximately the viewport size. Small deviations are permitted to ensure words are not broken.""" + # Handle empty pages + if len(self._page_content) == 0: + self.viewport_pages = [(0, 0)] + return + + # Break the viewport into pages + self.viewport_pages = [] + start_idx = 0 + while start_idx < len(self._page_content): + end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator] + # Adjust to end on a space + while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]: + end_idx += 1 + self.viewport_pages.append((start_idx, end_idx)) + start_idx = end_idx + + def _open_path( + self, + path: str, + ) -> None: + """Open a file for reading, converting it to Markdown in the process. + + Arguments: + path: The path of the file or directory to open. + """ + + if not self._validate_path(path): + # Not robust to TOCTOU issues. + # Mitigate by running with limited permissions, or use a sandbox. + self.page_title = "FileNotFoundError" + self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}") + else: + try: + if os.path.isdir(path): # TODO: Fix markdown_converter types + res = self._markdown_converter.convert_stream( # type: ignore + io.BytesIO(self._fetch_local_dir(path).encode("utf-8")), file_extension=".txt" + ) + assert self._validate_path(path) + self.page_title = res.title + self._set_page_content(res.text_content, split_pages=False) + else: + res = self._markdown_converter.convert_local(path) + assert self._validate_path(path) + self.page_title = res.title + self._set_page_content(res.text_content) + except UnsupportedFormatException: + self.page_title = "UnsupportedFormatException" + self._set_page_content(f"# UnsupportedFormatException\n\nCannot preview '{path}' as Markdown.") + except FileConversionException: + self.page_title = "FileConversionException." + self._set_page_content(f"# FileConversionException\n\nError converting '{path}' to Markdown.") + except FileNotFoundError: + self.page_title = "FileNotFoundError" + self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}") + + def _fetch_local_dir(self, local_path: str) -> str: + """Render a local directory listing in HTML to assist with local file browsing via the "file://" protocol. + Through rendered in HTML, later parts of the pipeline will convert the listing to Markdown. + + Arguments: + local_path: A path to the local directory whose contents are to be listed. + + Returns: + A directory listing, rendered in HTML. + """ + listing = f""" +# Index of {local_path} + +| Name | Size | Date Modified | +| ---- | ---- | ------------- | +| .. (parent directory) | | | +""" + for entry in os.listdir(local_path): + size = "" + full_path = os.path.join(local_path, entry) + + mtime = "" + try: + mtime = datetime.datetime.fromtimestamp(os.path.getmtime(full_path)).strftime("%Y-%m-%d %H:%M") + except Exception as e: + # Handles PermissionError, etc. + mtime = f"N/A: {type(e).__name__}" + + if os.path.isdir(full_path): + entry = entry + os.path.sep + else: + try: + size = str(os.path.getsize(full_path)) + except Exception as e: + # Handles PermissionError, etc. + size = f"N/A: {type(e).__name__}" + + listing += f"| {entry} | {size} | {mtime} |\n" + return listing diff --git a/agent_dhal/agentdhal_extensions/agents/file_surfer/_tool_definitions.py b/agent_dhal/agentdhal_extensions/agents/file_surfer/_tool_definitions.py new file mode 100644 index 0000000..4277b49 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/file_surfer/_tool_definitions.py @@ -0,0 +1,50 @@ +from agentdhal_core.tools import ParametersSchema, ToolSchema + +TOOL_OPEN_PATH = ToolSchema( + name="open_path", + description="Open a local file or directory at a path in the text-based file browser and return current viewport content.", + parameters=ParametersSchema( + type="object", + properties={ + "path": { + "type": "string", + "description": "The relative or absolute path of a local file to visit.", + }, + }, + required=["path"], + ), +) + + +TOOL_PAGE_UP = ToolSchema( + name="page_up", + description="Scroll the viewport UP one page-length in the current file and return the new viewport content.", +) + + +TOOL_PAGE_DOWN = ToolSchema( + name="page_down", + description="Scroll the viewport DOWN one page-length in the current file and return the new viewport content.", +) + + +TOOL_FIND_ON_PAGE_CTRL_F = ToolSchema( + name="find_on_page_ctrl_f", + description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F.", + parameters=ParametersSchema( + type="object", + properties={ + "search_string": { + "type": "string", + "description": "The string to search for on the page. This search string supports wildcards like '*'", + }, + }, + required=["search_string"], + ), +) + + +TOOL_FIND_NEXT = ToolSchema( + name="find_next", + description="Scroll the viewport to next occurrence of the search string.", +) diff --git a/agent_dhal/agentdhal_extensions/agents/magentic_one/__init__.py b/agent_dhal/agentdhal_extensions/agents/magentic_one/__init__.py new file mode 100644 index 0000000..a046770 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/magentic_one/__init__.py @@ -0,0 +1,10 @@ +try: + from ._magentic_one_coder_agent import MagenticOneCoderAgent +except ImportError as e: + raise ImportError( + "Dependencies for MagenticOneCoderAgent not found. " + 'Please install autogen-ext with the "magentic-one" extra: ' + 'pip install "agentdhal-ext[magentic-one]"' + ) from e + +__all__ = ["MagenticOneCoderAgent"] diff --git a/agent_dhal/agentdhal_extensions/agents/magentic_one/_magentic_one_coder_agent.py b/agent_dhal/agentdhal_extensions/agents/magentic_one/_magentic_one_coder_agent.py new file mode 100644 index 0000000..7df2a77 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/magentic_one/_magentic_one_coder_agent.py @@ -0,0 +1,41 @@ +from typing import Any + +from agentdhal_agentchat.agents import AssistantAgent +from agentdhal_core.models import ( + ChatCompletionClient, +) + +MAGENTIC_ONE_CODER_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills." + +MAGENTIC_ONE_CODER_SYSTEM_MESSAGE = """You are a helpful AI assistant. +Solve tasks using your coding and language skills. +In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute. + 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. +Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. +When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. +Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use the 'print' function for the output when relevant. Check the execution result returned by the user. +If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. +When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.""" + + +class MagenticOneCoderAgent(AssistantAgent): + """An agent, used by MagenticOne that provides coding assistance using an LLM model client. + + The prompts and description are sealed, to replicate the original MagenticOne configuration. See AssistantAgent if you wish to modify these values. + """ + + component_provider_override = "agentdhal_extensions.agents.magentic_one.MagenticOneCoderAgent" + + def __init__( + self, + name: str, + model_client: ChatCompletionClient, + **kwargs: Any, + ): + super().__init__( + name, + model_client, + description=MAGENTIC_ONE_CODER_DESCRIPTION, + system_message=MAGENTIC_ONE_CODER_SYSTEM_MESSAGE, + ) diff --git a/agent_dhal/agentdhal_extensions/agents/openai/__init__.py b/agent_dhal/agentdhal_extensions/agents/openai/__init__.py new file mode 100644 index 0000000..91936e6 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/openai/__init__.py @@ -0,0 +1,7 @@ +from ._openai_agent import OpenAIAgent +from ._openai_assistant_agent import OpenAIAssistantAgent + +__all__ = [ + "OpenAIAgent", + "OpenAIAssistantAgent", +] diff --git a/agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py b/agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py new file mode 100644 index 0000000..8d62827 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py @@ -0,0 +1,682 @@ +import asyncio +import logging +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Sequence, + Type, + Union, + cast, +) + +from agentdhal_agentchat import EVENT_LOGGER_NAME +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import ( + AgentEvent, + BaseChatMessage, + ChatMessage, + HandoffMessage, + MultiModalMessage, + StopMessage, + TextMessage, + ToolCallSummaryMessage, +) +from agentdhal_core import CancellationToken, Component +from agentdhal_core.models import UserMessage +from pydantic import BaseModel, Field +from typing_extensions import NotRequired, TypedDict + +from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore + +# Number of characters to display when previewing image content in logs and UI +# Base64 encoded images can be very long, so we truncate for readability +IMAGE_CONTENT_PREVIEW_LENGTH = 50 + +# NOTE: We use the new Responses API, so ChatCompletion imports are not needed. + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +# TypedDict classes for built-in tool configurations +class FileSearchToolConfig(TypedDict): + """Configuration for file_search tool.""" + + type: Literal["file_search"] + vector_store_ids: List[str] # required - The IDs of the vector stores to search + max_num_results: NotRequired[int] # optional + ranking_options: NotRequired[Dict[str, Any]] # optional + filters: NotRequired[Dict[str, Any]] # optional + + +class WebSearchToolConfig(TypedDict): + """Configuration for web_search_preview tool.""" + + type: Literal["web_search_preview"] + search_context_size: NotRequired[str] # optional + user_location: NotRequired[Union[str, Dict[str, Any]]] # optional - Can be string or structured location + + +class ComputerUseToolConfig(TypedDict): + """Configuration for computer_use_preview tool.""" + + type: Literal["computer_use_preview"] + display_height: int # required - Display height in pixels + display_width: int # required - Display width in pixels + environment: str # required - Environment type for computer use + + +class MCPToolConfig(TypedDict): + """Configuration for mcp tool.""" + + type: Literal["mcp"] + server_label: str # required - Label for the MCP server + server_url: str # required - URL of the MCP server + allowed_tools: NotRequired[List[str]] # optional - List of allowed tools + headers: NotRequired[Dict[str, str]] # optional - HTTP headers for requests + require_approval: NotRequired[bool] # optional - Whether to require user approval + + +class CodeInterpreterToolConfig(TypedDict): + """Configuration for code_interpreter tool.""" + + type: Literal["code_interpreter"] + container: str | Dict[str, Any] # required - Container configuration for code execution + + +class ImageGenerationToolConfig(TypedDict): + """Configuration for image_generation tool.""" + + type: Literal["image_generation"] + background: NotRequired[str] # optional - Background color or image + input_image_mask: NotRequired[str] # optional - Mask for input image editing + + +class LocalShellToolConfig(TypedDict): + """Configuration for local_shell tool. + + WARNING: This tool is only supported with the 'codex-mini-latest' model + and is available exclusively through the Responses API. + """ + + type: Literal["local_shell"] + # Note: local_shell currently has no additional parameters in the API + + +# Union type for all built-in tool configurations +BuiltinToolConfig = Union[ + FileSearchToolConfig, + WebSearchToolConfig, + ComputerUseToolConfig, + MCPToolConfig, + CodeInterpreterToolConfig, + ImageGenerationToolConfig, + LocalShellToolConfig, +] + + +# Define ImageMessage class early since it's used in _convert_message_to_openai_message +class ImageMessage(BaseChatMessage): + """A message containing an image.""" + + content: str # URL or base64 string + + def to_model_message(self) -> UserMessage: + return UserMessage(content=self.content, source=self.source) + + def to_model_text(self) -> str: + return "[image]" + + def to_text(self) -> str: + # Truncate long image content (especially base64) for better readability + # While still showing enough of the URL or content to be identifiable + if len(self.content) > IMAGE_CONTENT_PREVIEW_LENGTH: + return f"[Image: {self.content[:IMAGE_CONTENT_PREVIEW_LENGTH]}...]" + return f"[Image: {self.content}]" + + +class OpenAIMessageContent(TypedDict): + type: str + text: str + + +class OpenAIImageUrlContent(TypedDict): + url: str + + +class OpenAIImageContent(TypedDict): + type: str + image_url: OpenAIImageUrlContent + + +class OpenAIMessage(TypedDict): + role: str + content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]] + + +def _convert_message_to_openai_message( + message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage], +) -> OpenAIMessage: + """Convert an AutoGen message to an OpenAI message format.""" + if isinstance(message, TextMessage): + if message.source == "user": + return {"role": "user", "content": str(message.content)} + elif message.source == "system": + return {"role": "system", "content": str(message.content)} + elif message.source == "assistant": + return {"role": "assistant", "content": str(message.content)} + else: + return {"role": "user", "content": str(message.content)} + elif isinstance(message, MultiModalMessage): + content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = [] + for part in message.content: + if isinstance(part, TextMessage): + content_parts.append({"type": "text", "text": str(part.content)}) + elif isinstance(part, ImageMessage): + image_content = str(part.content) + content_parts.append({"type": "image_url", "image_url": {"url": image_content}}) + return {"role": "user", "content": content_parts} + else: + return {"role": "user", "content": str(message.content)} + + +class OpenAIAgentState(BaseModel): + type: str = Field(default="OpenAIAgentState") + response_id: Optional[str] = None + history: List[Dict[str, Any]] = Field(default_factory=list) + + +class OpenAIAgentConfig(BaseModel): + """ + Configuration model for OpenAI agent supporting OpenAI built-in tools only. + + .. versionchanged:: v0.7.0 + Added support for built-in tools in JSON configuration via _to_config and _from_config methods. + The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format). + Custom tools are not supported. + """ + + name: str + description: str + model: str + instructions: str + tools: List[Dict[str, Any] | str] | None = None + temperature: Optional[float] = 1 + max_output_tokens: Optional[int] = None + json_mode: bool = False + store: bool = True + truncation: str = "disabled" + + +class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]): + """ + An agent implementation that uses the OpenAI Responses API to generate responses. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[openai]" + # pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant + + This agent leverages the Responses API to generate responses with capabilities like: + + * Multi-turn conversations + * Built-in tool support (file_search, code_interpreter, web_search_preview, etc.) + + Currently, custom tools are not supported. + + .. versionchanged:: v0.7.0 + + Added support for built-in tool types like file_search, web_search_preview, + code_interpreter, computer_use_preview, image_generation, and mcp. + Added support for tool configurations with required and optional parameters. + + Built-in tools are split into two categories: + + **Tools that can use string format** (no required parameters): + + - web_search_preview: Can be used as "web_search_preview" or with optional config + (user_location, search_context_size) + - image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask) + - local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model) + + **Tools that REQUIRE dict configuration** (have required parameters): + + - file_search: MUST use dict with vector_store_ids (List[str]) + - computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str) + - code_interpreter: MUST use dict with container (str) + - mcp: MUST use dict with server_label (str), server_url (str) + + Using required-parameter tools in string format will raise a ValueError with helpful error messages. + The tools parameter type annotation only accepts string values for tools that don't require parameters. + + Note: + Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent. + Only OpenAI built-in tools provided via the Responses API are supported. + + + Args: + name (str): Name of the agent + description (str): Description of the agent's purpose + client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance + model (str): Model to use (e.g. "gpt-4.1") + instructions (str): System instructions for the agent + tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use. + Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell". + Dict values can provide configuration for built-in tools with parameters. + Required parameters for built-in tools: + - file_search: vector_store_ids (List[str]) + - computer_use_preview: display_height (int), display_width (int), environment (str) + - code_interpreter: container (str) + - mcp: server_label (str), server_url (str) + Optional parameters for built-in tools: + - file_search: max_num_results (int), ranking_options (dict), filters (dict) + - web_search_preview: user_location (str or dict), search_context_size (int) + - image_generation: background (str), input_image_mask (str) + - mcp: allowed_tools (List[str]), headers (dict), require_approval (bool) + Special tools with model restrictions: + - local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support) + Custom tools are not supported. + temperature (Optional[float]): Temperature for response generation (default: 1) + max_output_tokens (Optional[int]): Maximum output tokens + json_mode (bool): Whether to use JSON mode (default: False) + store (bool): Whether to store conversations (default: True) + truncation (str): Truncation strategy (default: "disabled") + + Example: + + Basic usage with built-in tools: + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.agents.openai import OpenAIAgent + from openai import AsyncOpenAI + + + async def example(): + client = AsyncOpenAI() + agent = OpenAIAgent( + name="SimpleAgent", + description="A simple OpenAI agent using the Responses API", + client=client, + model="gpt-4.1", + instructions="You are a helpful assistant.", + tools=["web_search_preview"], # Only tools without required params + ) + await Console(agent.run_stream(task="Search for recent AI developments")) + + + asyncio.run(example()) + + Usage with configured built-in tools: + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.agents.openai import OpenAIAgent + from openai import AsyncOpenAI + + + async def example_with_configs(): + client = AsyncOpenAI() + # Configure tools with required and optional parameters + tools = [ + # { + # "type": "file_search", + # "vector_store_ids": ["vs_abc123"], # required + # "max_num_results": 10, # optional + # }, + # { + # "type": "computer_use_preview", + # "display_height": 1024, # required + # "display_width": 1280, # required + # "environment": "linux", # required + # }, + { + "type": "code_interpreter", + "container": {"type": "auto"}, # required + }, + # { + # "type": "mcp", + # "server_label": "my-mcp-server", # required + # "server_url": "http://localhost:3000", # required + # }, + { + "type": "web_search_preview", + "user_location": { # optional - structured location + "type": "approximate", # required: "approximate" or "exact" + "country": "US", # optional + "region": "CA", # optional + "city": "San Francisco", # optional + }, + "search_context_size": "low", # optional + }, + # "image_generation", # Simple tools can still use string format + ] + + agent = OpenAIAgent( + name="ConfiguredAgent", + description="An agent with configured tools", + client=client, + model="gpt-4.1", + instructions="You are a helpful assistant with specialized tools.", + tools=tools, # type: ignore + ) + await Console(agent.run_stream(task="Search for recent AI developments")) + + + asyncio.run(example_with_configs()) + + + Note: + Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API. + + """ + + component_config_schema = OpenAIAgentConfig + component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent" + + def __init__( + self: "OpenAIAgent", + name: str, + description: str, + client: Union[AsyncOpenAI, AsyncAzureOpenAI], + model: str, + instructions: str, + tools: Optional[ + Iterable[ + Union[ + Literal["web_search_preview", "image_generation", "local_shell"], + BuiltinToolConfig, + ] + ] + ] = None, + temperature: Optional[float] = 1, + max_output_tokens: Optional[int] = None, + json_mode: bool = False, + store: bool = True, + truncation: str = "disabled", + ) -> None: + super().__init__(name, description) + self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client + self._model: str = model + self._instructions: str = instructions + self._temperature: Optional[float] = temperature + self._max_output_tokens: Optional[int] = max_output_tokens + self._json_mode: bool = json_mode + self._store: bool = store + self._truncation: str = truncation + self._last_response_id: Optional[str] = None + self._message_history: List[Dict[str, Any]] = [] + self._tools: List[Dict[str, Any]] = [] + if tools is not None: + for tool in tools: + if isinstance(tool, str): + # Handle built-in tool types + self._add_builtin_tool(tool) + elif isinstance(tool, dict) and "type" in tool: + # Handle configured built-in tools + self._tools.append(cast(dict[str, Any], tool)) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + + def _add_builtin_tool(self, tool_name: str) -> None: + """Add a built-in tool by name.""" + # Skip if an identical tool has already been registered (idempotent behaviour) + if any(td.get("type") == tool_name for td in self._tools): + return # Duplicate – ignore rather than raise to stay backward-compatible + # Only allow string format for tools that don't require parameters + if tool_name == "web_search_preview": + self._tools.append({"type": "web_search_preview"}) + elif tool_name == "image_generation": + self._tools.append({"type": "image_generation"}) + elif tool_name == "local_shell": + # Special handling for local_shell - very limited model support + if self._model != "codex-mini-latest": + raise ValueError( + f"Tool 'local_shell' is only supported with model 'codex-mini-latest', " + f"but current model is '{self._model}'. " + f"This tool is available exclusively through the Responses API and has severe limitations. " + f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with " + f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead." + ) + self._tools.append({"type": "local_shell"}) + elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]: + # These tools require specific parameters and must use dict configuration + raise ValueError( + f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. " + f"Use dict configuration instead. Required parameters for {tool_name}: " + f"{self._get_required_params_help(tool_name)}" + ) + else: + raise ValueError(f"Unsupported built-in tool type: {tool_name}") + + def _get_required_params_help(self, tool_name: str) -> str: + """Get help text for required parameters of a tool.""" + help_text = { + "file_search": "vector_store_ids (List[str])", + "code_interpreter": "container (str | dict)", + "computer_use_preview": "display_height (int), display_width (int), environment (str)", + "mcp": "server_label (str), server_url (str)", + } + return help_text.get(tool_name, "unknown parameters") + + def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]: + """Convert an OpenAIMessage to a Dict[str, Any].""" + return dict(message) + + @property + def produced_message_types( + self: "OpenAIAgent", + ) -> Sequence[ + Union[ + Type[TextMessage], + Type[MultiModalMessage], + Type[StopMessage], + Type[ToolCallSummaryMessage], + Type[HandoffMessage], + ] + ]: + """Return the types of messages that this agent can produce.""" + return [TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage] + + # Custom tool execution is not supported by this agent. + + def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]: + has_system_message = any(msg.get("role") == "system" for msg in messages) + if self._instructions and not has_system_message: + messages = [{"role": "system", "content": self._instructions}] + messages + api_params: Dict[str, Any] = { + "model": self._model, + "input": messages, # Responses API expects 'input' + } + if self._temperature is not None: + api_params["temperature"] = self._temperature + if self._max_output_tokens is not None: + api_params["max_output_tokens"] = self._max_output_tokens + if self._tools: + api_params["tools"] = self._tools + if self._json_mode: + api_params["text"] = {"type": "json_object"} + api_params["store"] = self._store + api_params["truncation"] = self._truncation + if self._last_response_id: + api_params["previous_response_id"] = self._last_response_id + return api_params + + async def on_messages( + self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> Response: + response = None + inner_messages: List[ + Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage] + ] = [] + + async for msg in self.on_messages_stream(messages, cancellation_token): + if isinstance(msg, Response): + response = msg + # ModelClientStreamingChunkEvent does not exist in this version, so skip this check + else: + inner_messages.append(msg) + + if response is None: + raise ValueError("No response was generated") + + if response.inner_messages is None: + response.inner_messages = [] + + for msg in inner_messages: + if msg not in response.inner_messages: + response.inner_messages = list(response.inner_messages) + [msg] + + return response + + async def on_messages_stream( + self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[ + Union[ + AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response + ], + None, + ]: + input_messages: List[Dict[str, Any]] = [] + + if self._message_history: + input_messages.extend(self._message_history) + + for message in messages: + if isinstance( + message, (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage) + ): + openai_message = _convert_message_to_openai_message(message) + dict_message = self._convert_message_to_dict(openai_message) + input_messages.append(dict_message) + self._message_history.append(dict_message) + else: + msg_content = str(cast(Any, message).content) if hasattr(message, "content") else str(message) + dict_message = {"role": "user", "content": msg_content} + input_messages.append(dict_message) + self._message_history.append(dict_message) + + inner_messages: List[AgentEvent | ChatMessage] = [] + + api_params = self._build_api_parameters(input_messages) + + try: + client = cast(Any, self._client) + response_obj = await cancellation_token.link_future( + asyncio.ensure_future(client.responses.create(**api_params)) + ) + content = getattr(response_obj, "output_text", None) + response_id = getattr(response_obj, "id", None) + self._last_response_id = response_id + # Use a readable placeholder when the API returns no content to aid debugging + content_str: str = str(content) if content is not None else "[no content returned]" + self._message_history.append({"role": "assistant", "content": content_str}) + final_message = TextMessage(source=self.name, content=content_str) + response = Response(chat_message=final_message, inner_messages=inner_messages) + yield response + except Exception as e: + error_message = f"Error generating response: {str(e)}" + event_logger.error(f"API error: {error_message}", exc_info=True) + error_response = TextMessage(source=self.name, content=error_message) + yield Response(chat_message=error_response, inner_messages=inner_messages) + + async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None: + self._last_response_id = None + self._message_history = [] + + async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]: + state = OpenAIAgentState( + response_id=self._last_response_id, + history=self._message_history, + ) + return state.model_dump() + + async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None: + agent_state = OpenAIAgentState.model_validate(state) + self._last_response_id = agent_state.response_id + self._message_history = agent_state.history + + def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig: + """Convert the OpenAI agent to a declarative config. + + Serializes built-in tools to their appropriate configuration formats for JSON serialization. + + Returns: + OpenAIAgentConfig: The configuration that can recreate this agent. + """ + return OpenAIAgentConfig( + name=self.name, + description=self.description, + model=self._model, + instructions=self._instructions, + tools=list(self._tools), + temperature=self._temperature, + max_output_tokens=self._max_output_tokens, + json_mode=self._json_mode, + store=self._store, + truncation=self._truncation, + ) + + @classmethod + def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent": + """Create an OpenAI agent from a declarative config. + + Handles built-in tools (from string or dict configurations). + + Args: + config: The configuration to load the agent from. + + Returns: + OpenAIAgent: The reconstructed agent. + """ + from openai import AsyncOpenAI + + client = AsyncOpenAI() + + return cls( + name=config.name, + description=config.description, + client=client, + model=config.model, + instructions=config.instructions, + tools=config.tools, # type: ignore + temperature=config.temperature, + max_output_tokens=config.max_output_tokens, + json_mode=config.json_mode, + store=config.store, + truncation=config.truncation, + ) + + # Add public API wrappers for configuration and tools + def to_config(self) -> OpenAIAgentConfig: + """Public wrapper for the private _to_config method.""" + return self._to_config() + + @classmethod + def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent": + """Public wrapper for the private _from_config classmethod.""" + return cls._from_config(config) + + @property + def tools(self) -> list[Any]: + """Public access to the agent's tools.""" + return self._tools + + @property + def model(self) -> str: + """Public access to the agent's model.""" + return self._model diff --git a/agent_dhal/agentdhal_extensions/agents/openai/_openai_assistant_agent.py b/agent_dhal/agentdhal_extensions/agents/openai/_openai_assistant_agent.py new file mode 100644 index 0000000..730dccc --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/openai/_openai_assistant_agent.py @@ -0,0 +1,715 @@ +import asyncio +import json +import logging +import os +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Sequence, + Set, + Union, + cast, +) + +import aiofiles +from agentdhal_agentchat import EVENT_LOGGER_NAME +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import ( + BaseAgentEvent, + BaseChatMessage, + TextMessage, + ToolCallExecutionEvent, + ToolCallRequestEvent, +) +from agentdhal_core import CancellationToken, FunctionCall, Image +from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult +from agentdhal_core.tools import FunctionTool, Tool +from pydantic import BaseModel, Field + +from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven +from openai.pagination import AsyncCursorPage +from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads +from openai.types import FileObject +from openai.types.beta import thread_update_params +from openai.types.beta.assistant import Assistant +from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam +from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam +from openai.types.beta.file_search_tool_param import FileSearchToolParam +from openai.types.beta.function_tool_param import FunctionToolParam +from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter +from openai.types.beta.threads import Message, MessageDeleted, Run +from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam +from openai.types.beta.threads.image_url_param import ImageURLParam +from openai.types.beta.threads.message_content_part_param import ( + MessageContentPartParam, +) +from openai.types.beta.threads.text_content_block_param import TextContentBlockParam +from openai.types.shared_params.function_definition import FunctionDefinition +from openai.types.vector_store import VectorStore + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam": + """Convert an autogen Tool to an OpenAI Assistant function tool parameter.""" + + schema = tool.schema + parameters: Dict[str, object] = {} + if "parameters" in schema: + parameters = { + "type": schema["parameters"]["type"], + "properties": schema["parameters"]["properties"], + } + if "required" in schema["parameters"]: + parameters["required"] = schema["parameters"]["required"] + + function_def = FunctionDefinition( + name=schema["name"], + description=schema.get("description", ""), + parameters=parameters, + ) + return FunctionToolParam(type="function", function=function_def) + + +class OpenAIAssistantAgentState(BaseModel): + type: str = Field(default="OpenAIAssistantAgentState") + assistant_id: Optional[str] = None + thread_id: Optional[str] = None + initial_message_ids: List[str] = Field(default_factory=list) + vector_store_id: Optional[str] = None + uploaded_file_ids: List[str] = Field(default_factory=list) + + +class OpenAIAssistantAgent(BaseChatAgent): + """An agent implementation that uses the Assistant API to generate responses. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[openai]" # For OpenAI Assistant + # pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant + + + This agent leverages the Assistant API to create AI assistants with capabilities like: + + * Code interpretation and execution + * File handling and search + * Custom function calling + * Multi-turn conversations + + The agent maintains a thread of conversation and can use various tools including + + * Code interpreter: For executing code and working with files + * File search: For searching through uploaded documents + * Custom functions: For extending capabilities with user-defined tools + + Key Features: + + * Supports multiple file formats including code, documents, images + * Can handle up to 128 tools per assistant + * Maintains conversation context in threads + * Supports file uploads for code interpreter and search + * Vector store integration for efficient file search + * Automatic file parsing and embedding + + You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters. + + Examples: + + Use the assistant to analyze data in a CSV file: + + .. code-block:: python + + from openai import AsyncOpenAI + from agentdhal_core import CancellationToken + import asyncio + from agentdhal_extensions.agents.openai import OpenAIAssistantAgent + from agentdhal_agentchat.messages import TextMessage + + + async def example(): + cancellation_token = CancellationToken() + + # Create an OpenAI client + client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url") + + # Create an assistant with code interpreter + assistant = OpenAIAssistantAgent( + name="PythonHelper", + description="Helps with Python programming", + client=client, + model="gpt-4", + instructions="You are a helpful Python programming assistant.", + tools=["code_interpreter"], + ) + + # Upload files for the assistant to use + await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token) + + # Get response from the assistant + response = await assistant.on_messages( + [TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token + ) + + print(response) + + # Clean up resources + await assistant.delete_uploaded_files(cancellation_token) + await assistant.delete_assistant(cancellation_token) + + + asyncio.run(example()) + + Use Azure OpenAI Assistant with AAD authentication: + + .. code-block:: python + + from openai import AsyncAzureOpenAI + import asyncio + from azure.identity import DefaultAzureCredential, get_bearer_token_provider + from agentdhal_core import CancellationToken + from agentdhal_extensions.agents.openai import OpenAIAssistantAgent + from agentdhal_agentchat.messages import TextMessage + + + async def example(): + cancellation_token = CancellationToken() + + # Create an Azure OpenAI client + token_provider = get_bearer_token_provider(DefaultAzureCredential()) + client = AsyncAzureOpenAI( + azure_deployment="YOUR_AZURE_DEPLOYMENT", + api_version="YOUR_API_VERSION", + azure_endpoint="YOUR_AZURE_ENDPOINT", + azure_ad_token_provider=token_provider, + ) + + # Create an assistant with code interpreter + assistant = OpenAIAssistantAgent( + name="PythonHelper", + description="Helps with Python programming", + client=client, + model="gpt-4o", + instructions="You are a helpful Python programming assistant.", + tools=["code_interpreter"], + ) + + # Get response from the assistant + response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token) + + print(response) + + # Clean up resources + await assistant.delete_assistant(cancellation_token) + + + asyncio.run(example()) + + Args: + name (str): Name of the assistant + description (str): Description of the assistant's purpose + client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance + model (str): Model to use (e.g. "gpt-4") + instructions (str): System instructions for the assistant + tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use + assistant_id (Optional[str]): ID of existing assistant to use + thread_id (Optional[str]): ID of existing thread to use + metadata (Optional[Dict[str, str]]): Additional metadata for the assistant. + response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings + temperature (Optional[float]): Temperature for response generation + tool_resources (Optional[ToolResources]): Additional tool configuration + top_p (Optional[float]): Top p sampling parameter + """ + + def __init__( + self, + name: str, + description: str, + client: AsyncOpenAI | AsyncAzureOpenAI, + model: str, + instructions: str, + tools: Optional[ + Iterable[ + Union[ + Literal["code_interpreter", "file_search"], + Tool | Callable[..., Any] | Callable[..., Awaitable[Any]], + ] + ] + ] = None, + assistant_id: Optional[str] = None, + thread_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + response_format: Optional["AssistantResponseFormatOptionParam"] = None, + temperature: Optional[float] = None, + tool_resources: Optional["ToolResources"] = None, + top_p: Optional[float] = None, + ) -> None: + if isinstance(client, ChatCompletionClient): + raise ValueError( + "Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance." + ) + + super().__init__(name, description) + if tools is None: + tools = [] + + # Store original tools and converted tools separately + self._original_tools: List[Tool] = [] + converted_tools: List["AssistantToolParam"] = [] + for tool in tools: + if isinstance(tool, str): + if tool == "code_interpreter": + converted_tools.append(CodeInterpreterToolParam(type="code_interpreter")) + elif tool == "file_search": + converted_tools.append(FileSearchToolParam(type="file_search")) + elif isinstance(tool, Tool): + self._original_tools.append(tool) + converted_tools.append(_convert_tool_to_function_param(tool)) + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + function_tool = FunctionTool(tool, description=description) + self._original_tools.append(function_tool) + converted_tools.append(_convert_tool_to_function_param(function_tool)) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + + self._client = client + self._assistant: Optional["Assistant"] = None + self._thread: Optional["Thread"] = None + self._init_thread_id = thread_id + self._model = model + self._instructions = instructions + self._api_tools = converted_tools + self._assistant_id = assistant_id + self._metadata = metadata + self._response_format = response_format + self._temperature = temperature + self._tool_resources = tool_resources + self._top_p = top_p + self._vector_store_id: Optional[str] = None + self._uploaded_file_ids: List[str] = [] + + # Variables to track initial state + self._initial_message_ids: Set[str] = set() + self._initial_state_retrieved: bool = False + + async def _ensure_initialized(self) -> None: + """Ensure assistant and thread are created.""" + if self._assistant is None: + if self._assistant_id: + self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) # type: ignore[reportDeprecated] + else: + self._assistant = await self._client.beta.assistants.create( # type: ignore[reportDeprecated] + model=self._model, + description=self.description, + instructions=self._instructions, + tools=self._api_tools, + metadata=self._metadata, + response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore + temperature=self._temperature, + tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore + top_p=self._top_p, + ) + + if self._thread is None: + if self._init_thread_id: + self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) # type: ignore[reportDeprecated] + else: + self._thread = await self._client.beta.threads.create() # type: ignore[reportDeprecated] + + # Retrieve initial state only once + if not self._initial_state_retrieved: + await self._retrieve_initial_state() + self._initial_state_retrieved = True + + async def _retrieve_initial_state(self) -> None: + """Retrieve and store the initial state of messages and runs.""" + # Retrieve all initial message IDs + initial_message_ids: Set[str] = set() + after: str | NotGiven = NOT_GIVEN + while True: + msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list( # type: ignore[reportDeprecated] + self._thread_id, after=after, order="asc", limit=100 + ) + for msg in msgs.data: + initial_message_ids.add(msg.id) + if not msgs.has_next_page(): + break + after = msgs.data[-1].id + self._initial_message_ids = initial_message_ids + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + """The types of messages that the assistant agent produces.""" + return (TextMessage,) + + @property + def threads(self) -> AsyncThreads: + return self._client.beta.threads + + @property + def runs(self) -> AsyncRuns: + return self._client.beta.threads.runs + + @property + def messages(self) -> AsyncMessages: + return self._client.beta.threads.messages + + @property + def _get_assistant_id(self) -> str: + if self._assistant is None: + raise ValueError("Assistant not initialized") + return self._assistant.id + + @property + def _thread_id(self) -> str: + if self._thread is None: + raise ValueError("Thread not initialized") + return self._thread.id + + async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str: + """Execute a tool call and return the result.""" + if not self._original_tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._original_tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token, call_id=tool_call.id) + return tool.return_value_as_string(result) + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + """Handle incoming messages and return a response.""" + + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + """Handle incoming messages and return a response.""" + await self._ensure_initialized() + + # Process all messages in sequence + for message in messages: + await self.handle_incoming_message(message, cancellation_token) + + # Inner messages for tool calls + inner_messages: List[BaseAgentEvent | BaseChatMessage] = [] + + # Create and start a run + run: Run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.create( # type: ignore[reportDeprecated] + thread_id=self._thread_id, + assistant_id=self._get_assistant_id, + ) + ) + ) + + # Wait for run completion by polling + while True: + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.retrieve( # type: ignore[reportDeprecated] + thread_id=self._thread_id, + run_id=run.id, + ) + ) + ) + + if run.status == "failed": + raise ValueError(f"Run failed: {run.last_error}") + + # If the run requires action (function calls), execute tools and continue + if run.status == "requires_action" and run.required_action is not None: + tool_calls: List[FunctionCall] = [] + for required_tool_call in run.required_action.submit_tool_outputs.tool_calls: + if required_tool_call.type == "function": + tool_calls.append( + FunctionCall( + id=required_tool_call.id, + name=required_tool_call.function.name, + arguments=required_tool_call.function.arguments, + ) + ) + + # Add tool call message to inner messages + tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls) + inner_messages.append(tool_call_msg) + event_logger.debug(tool_call_msg) + yield tool_call_msg + + # Execute tool calls and get results + tool_outputs: List[FunctionExecutionResult] = [] + for tool_call in tool_calls: + try: + result = await self._execute_tool_call(tool_call, cancellation_token) + is_error = False + except Exception as e: + result = f"Error: {e}" + is_error = True + tool_outputs.append( + FunctionExecutionResult( + content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name + ) + ) + + # Add tool result message to inner messages + tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs) + inner_messages.append(tool_result_msg) + event_logger.debug(tool_result_msg) + yield tool_result_msg + + # Submit tool outputs back to the run + run = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.runs.submit_tool_outputs( # type: ignore[reportDeprecated] + thread_id=self._thread_id, + run_id=run.id, + tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs], + ) + ) + ) + continue + + if run.status == "completed": + break + + await asyncio.sleep(0.5) + + # Get messages after run completion + assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) # type: ignore[reportDeprecated] + ) + ) + + if not assistant_messages.data: + raise ValueError("No messages received from assistant") + + # Get the last message's content + last_message = assistant_messages.data[0] + if not last_message.content: + raise ValueError(f"No content in the last message: {last_message}") + + # Extract text content + text_content = [content for content in last_message.content if content.type == "text"] + if not text_content: + raise ValueError(f"Expected text content in the last message: {last_message.content}") + + # Return the assistant's response as a Response with inner messages + chat_message = TextMessage(source=self.name, content=text_content[0].text.value) + yield Response(chat_message=chat_message, inner_messages=inner_messages) + + async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None: + """Handle regular text messages by adding them to the thread.""" + content: str | List[MessageContentPartParam] | None = None + llm_message = message.to_model_message() + if isinstance(llm_message.content, str): + content = llm_message.content + else: + content = [] + for c in llm_message.content: + if isinstance(c, str): + content.append(TextContentBlockParam(text=c, type="text")) + elif isinstance(c, Image): + content.append(ImageURLContentBlockParam(image_url=ImageURLParam(url=c.data_uri), type="image_url")) + else: + raise ValueError(f"Unsupported content type: {type(c)} in {message}") + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.create( # type: ignore[reportDeprecated] + thread_id=self._thread_id, + content=content, + role="user", + ) + ) + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + """Handle reset command by deleting new messages and runs since initialization.""" + await self._ensure_initialized() + + # Retrieve all message IDs in the thread + new_message_ids: List[str] = [] + after: str | NotGiven = NOT_GIVEN + while True: + msgs: AsyncCursorPage[Message] = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100) # type: ignore[reportDeprecated] + ) + ) + for msg in msgs.data: + if msg.id not in self._initial_message_ids: + new_message_ids.append(msg.id) + if not msgs.has_next_page(): + break + after = msgs.data[-1].id + + # Delete new messages + for msg_id in new_message_ids: + status: MessageDeleted = await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) # type: ignore[reportDeprecated] + ) + ) + assert status.deleted is True + + async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]: + """Upload files and return their IDs.""" + await self._ensure_initialized() + + if isinstance(file_paths, str): + file_paths = [file_paths] + + file_ids: List[str] = [] + for file_path in file_paths: + async with aiofiles.open(file_path, mode="rb") as f: + file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read())) + file_name = os.path.basename(file_path) + + file: FileObject = await cancellation_token.link_future( + asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants")) + ) + file_ids.append(file.id) + self._uploaded_file_ids.append(file.id) + + return file_ids + + async def on_upload_for_code_interpreter( + self, file_paths: str | Iterable[str], cancellation_token: CancellationToken + ) -> None: + """Handle file uploads for the code interpreter.""" + await self._ensure_initialized() + + file_ids = await self._upload_files(file_paths, cancellation_token) + + # Update thread with the new files + thread = await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) # type: ignore[reportDeprecated] + ) + tool_resources: ToolResources = thread.tool_resources or ToolResources() + code_interpreter: ToolResourcesCodeInterpreter = ( + tool_resources.code_interpreter or ToolResourcesCodeInterpreter() + ) + existing_file_ids: List[str] = code_interpreter.file_ids or [] + existing_file_ids.extend(file_ids) + tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=existing_file_ids) + + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.threads.update( # type: ignore[reportDeprecated] + thread_id=self._thread_id, + tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()), + ) + ) + ) + + async def on_upload_for_file_search( + self, file_paths: str | Iterable[str], cancellation_token: CancellationToken + ) -> None: + """Handle file uploads for file search.""" + await self._ensure_initialized() + + # Check if file_search is enabled in tools + if not any(tool.get("type") == "file_search" for tool in self._api_tools): + raise ValueError( + "File search is not enabled for this assistant. Add a file_search tool when creating the assistant." + ) + + # Create vector store if not already created + if self._vector_store_id is None: + vector_store: VectorStore = await cancellation_token.link_future( + asyncio.ensure_future(self._client.vector_stores.create()) + ) + self._vector_store_id = vector_store.id + + # Update assistant with vector store ID + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.beta.assistants.update( + assistant_id=self._get_assistant_id, + tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}}, + ) + ) + ) + + file_ids = await self._upload_files(file_paths, cancellation_token) + + # Create file batch with the file IDs + await cancellation_token.link_future( + asyncio.ensure_future( + self._client.vector_stores.file_batches.create_and_poll( + vector_store_id=self._vector_store_id, file_ids=file_ids + ) + ) + ) + + async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None: + """Delete all files that were uploaded by this agent instance.""" + await self._ensure_initialized() + for file_id in self._uploaded_file_ids: + try: + await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id))) + except Exception as e: + event_logger.error(f"Failed to delete file {file_id}: {str(e)}") + self._uploaded_file_ids = [] + + async def delete_assistant(self, cancellation_token: CancellationToken) -> None: + """Delete the assistant if it was created by this instance.""" + await self._ensure_initialized() + if self._assistant is not None and not self._assistant_id: + try: + await cancellation_token.link_future( + asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)) # type: ignore[reportDeprecated] + ) + self._assistant = None + except Exception as e: + event_logger.error(f"Failed to delete assistant: {str(e)}") + + async def delete_vector_store(self, cancellation_token: CancellationToken) -> None: + """Delete the vector store if it was created by this instance.""" + await self._ensure_initialized() + if self._vector_store_id is not None: + try: + await cancellation_token.link_future( + asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=self._vector_store_id)) + ) + self._vector_store_id = None + except Exception as e: + event_logger.error(f"Failed to delete vector store: {str(e)}") + + async def save_state(self) -> Mapping[str, Any]: + state = OpenAIAssistantAgentState( + assistant_id=self._assistant.id if self._assistant else self._assistant_id, + thread_id=self._thread.id if self._thread else self._init_thread_id, + initial_message_ids=list(self._initial_message_ids), + vector_store_id=self._vector_store_id, + uploaded_file_ids=self._uploaded_file_ids, + ) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + agent_state = OpenAIAssistantAgentState.model_validate(state) + self._assistant_id = agent_state.assistant_id + self._init_thread_id = agent_state.thread_id + self._initial_message_ids = set(agent_state.initial_message_ids) + self._vector_store_id = agent_state.vector_store_id + self._uploaded_file_ids = agent_state.uploaded_file_ids diff --git a/agent_dhal/agentdhal_extensions/agents/video_surfer/__init__.py b/agent_dhal/agentdhal_extensions/agents/video_surfer/__init__.py new file mode 100644 index 0000000..cab75c5 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/video_surfer/__init__.py @@ -0,0 +1,3 @@ +from ._video_surfer import VideoSurfer + +__all__ = ["VideoSurfer"] diff --git a/agent_dhal/agentdhal_extensions/agents/video_surfer/_video_surfer.py b/agent_dhal/agentdhal_extensions/agents/video_surfer/_video_surfer.py new file mode 100644 index 0000000..7d57f9d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/video_surfer/_video_surfer.py @@ -0,0 +1,172 @@ +from typing import Any, Awaitable, Callable, List, Optional + +from agentdhal_agentchat.agents import AssistantAgent +from agentdhal_core.models import ChatCompletionClient +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel + +from .tools import ( + extract_audio, + get_screenshot_at, + get_video_length, + save_screenshot, + transcribe_audio_with_timestamps, + transcribe_video_screenshot, +) + + +class VideoSurfer(AssistantAgent): + """ + VideoSurfer is a specialized agent designed to answer questions about a local video file. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[video-surfer]" + + This agent utilizes various tools to extract information from the video, such as its length, screenshots at specific timestamps, and audio transcriptions. It processes these elements to provide detailed answers to user queries. + + Available tools: + + - :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio` + - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length` + - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps` + - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at` + - :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot` + - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot` + + Args: + name (str): The name of the agent. + model_client (ChatCompletionClient): The model client used for generating responses. + tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): + A list of tools or functions the agent can use. If not provided, defaults to all video tools from the action space. + description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.". + system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message. + + Example usage: + + The following example demonstrates how to create an video surfing agent with + a model client and generate a response to a simple query about a local video + called video.mp4. + + .. code-block:: python + + + import asyncio + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.agents.video_surfer import VideoSurfer + + async def main() -> None: + \"\"\" + Main function to run the video agent. + \"\"\" + # Define an agent + video_agent = VideoSurfer( + name="VideoSurfer", + model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06") + ) + + # Define termination condition + termination = TextMentionTermination("TERMINATE") + + # Define a team + agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination) + + # Run the team and stream messages to the console + stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of complex does his use? Can you save this example to disk as well?") + await Console(stream) + + asyncio.run(main()) + + The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat. + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.teams import MagenticOneGroupChat + from agentdhal_agentchat.agents import UserProxyAgent + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.agents.video_surfer import VideoSurfer + + async def main() -> None: + \"\"\" + Main function to run the video agent. + \"\"\" + + model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06") + + # Define an agent + video_agent = VideoSurfer( + name="VideoSurfer", + model_client=model_client + ) + + web_surfer_agent = UserProxyAgent( + name="User" + ) + + # Define a team + agent_team = MagenticOneGroupChat([web_surfer_agent, video_agent], model_client=model_client,) + + # Run the team and stream messages to the console + stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.") + await Console(stream) + + asyncio.run(main()) + """ + + DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video." + + DEFAULT_SYSTEM_MESSAGE = """ + You are a helpful agent that is an expert at answering questions from a video. + When asked to answer a question about a video, you should: + 1. Check if that video is available locally. + 2. Use the transcription to find which part of the video the question is referring to. + 3. Optionally use screenshots from those timestamps + 4. Provide a detailed answer to the question. + Reply with TERMINATE when the task has been completed. + """ + + def __init__( + self, + name: str, + model_client: ChatCompletionClient, + *, + tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, + description: Optional[str] = None, + system_message: Optional[str] = None, + ): + super().__init__( + name=name, + model_client=model_client, + tools=tools + or [ + get_video_length, + get_screenshot_at, + save_screenshot, + self.vs_transribe_video_screenshot, + extract_audio, + transcribe_audio_with_timestamps, + ], + description=description or self.DEFAULT_DESCRIPTION, + system_message=system_message or self.DEFAULT_SYSTEM_MESSAGE, + ) + + async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str: + """ + Transcribes the video screenshot at a specific timestamp. + + Args: + video_path (str): Path to the video file. + timestamp (float): Timestamp to take the screenshot. + + Returns: + str: Transcription of the video screenshot. + """ + return await transcribe_video_screenshot(video_path, timestamp, self._model_client) diff --git a/agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py b/agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py new file mode 100644 index 0000000..ee0ba99 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py @@ -0,0 +1,156 @@ +import base64 +from typing import Any, Dict, List, Tuple + +import cv2 +import ffmpeg +import numpy as np +import whisper +from agentdhal_core import Image as AGImage +from agentdhal_core.models import ( + ChatCompletionClient, + UserMessage, +) + + +def extract_audio(video_path: str, audio_output_path: str) -> str: + """ + Extracts audio from a video file and saves it as an MP3 file. + + :param video_path: Path to the video file. + :param audio_output_path: Path to save the extracted audio file. + :return: Confirmation message with the path to the saved audio file. + """ + (ffmpeg.input(video_path).output(audio_output_path, format="mp3").run(quiet=True, overwrite_output=True)) # type: ignore + return f"Audio extracted and saved to {audio_output_path}." + + +def transcribe_audio_with_timestamps(audio_path: str) -> str: + """ + Transcribes the audio file with timestamps using the Whisper model. + + :param audio_path: Path to the audio file. + :return: Transcription with timestamps. + """ + model = whisper.load_model("base") # type: ignore + result: Dict[str, Any] = model.transcribe(audio_path, task="transcribe", language="en", verbose=False) # type: ignore + + segments: List[Dict[str, Any]] = result["segments"] + transcription_with_timestamps = "" + + for segment in segments: + start: float = segment["start"] + end: float = segment["end"] + text: str = segment["text"] + transcription_with_timestamps += f"[{start:.2f} - {end:.2f}] {text}\n" + + return transcription_with_timestamps + + +def get_video_length(video_path: str) -> str: + """ + Returns the length of the video in seconds. + + :param video_path: Path to the video file. + :return: Duration of the video in seconds. + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Cannot open video file {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + duration = frame_count / fps + cap.release() + + return f"The video is {duration:.2f} seconds long." + + +def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None: + """ + Captures a screenshot at the specified timestamp and saves it to the output path. + + :param video_path: Path to the video file. + :param timestamp: Timestamp in seconds. + :param output_path: Path to save the screenshot. The file format is determined by the extension in the path. + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Cannot open video file {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_number = int(timestamp * fps) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) + ret, frame = cap.read() + if ret: + cv2.imwrite(output_path, frame) + else: + raise IOError(f"Failed to capture frame at {timestamp:.2f}s") + cap.release() + + +async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str: + """ + Transcribes the content of a video screenshot captured at the specified timestamp using OpenAI API. + + :param video_path: Path to the video file. + :param timestamp: Timestamp in seconds. + :param model_client: ChatCompletionClient instance. + :return: Description of the screenshot content. + """ + screenshots = get_screenshot_at(video_path, [timestamp]) + if not screenshots: + return "Failed to capture screenshot." + + _, frame = screenshots[0] + # Convert the frame to bytes and then to base64 encoding + _, buffer = cv2.imencode(".jpg", frame) + frame_bytes = buffer.tobytes() + frame_base64 = base64.b64encode(frame_bytes).decode("utf-8") + screenshot_uri = f"data:image/jpeg;base64,{frame_base64}" + + messages = [ + UserMessage( + content=[ + "Following is a screenshot from the video at {} seconds. Describe what you see here.", + AGImage.from_uri(screenshot_uri), + ], + source="tool", + ) + ] + + result = await model_client.create(messages=messages) + return str(result.content) + + +def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]: + """ + Captures screenshots at the specified timestamps and returns them as Python objects. + + :param video_path: Path to the video file. + :param timestamps: List of timestamps in seconds. + :return: List of tuples containing timestamp and the corresponding frame (image). + Each frame is a NumPy array (height x width x channels). + """ + screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = [] + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Cannot open video file {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) + duration = total_frames / fps + + for timestamp in timestamps: + if 0 <= timestamp <= duration: + frame_number = int(timestamp * fps) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) + ret, frame = cap.read() + if ret: + # Append the timestamp and frame to the list + screenshots.append((timestamp, frame)) + else: + raise IOError(f"Failed to capture frame at {timestamp:.2f}s") + else: + raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]") + + cap.release() + return screenshots diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/__init__.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/__init__.py new file mode 100644 index 0000000..5b3efc9 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/__init__.py @@ -0,0 +1,4 @@ +from ._multimodal_web_surfer import MultimodalWebSurfer +from .playwright_controller import PlaywrightController + +__all__ = ["MultimodalWebSurfer", "PlaywrightController"] diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py new file mode 100644 index 0000000..3468f41 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class WebSurferEvent: + source: str + message: str + url: str + action: str | None = None + arguments: Dict[str, Any] | None = None diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_multimodal_web_surfer.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_multimodal_web_surfer.py new file mode 100644 index 0000000..ba3b35c --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_multimodal_web_surfer.py @@ -0,0 +1,988 @@ +import asyncio +import base64 +import hashlib +import io +import json +import logging +import os +import re +import sys +import time +import traceback +import warnings +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Sequence, +) +from urllib.parse import quote_plus + +import aiofiles +import PIL.Image +from agentdhal_agentchat.agents import BaseChatAgent +from agentdhal_agentchat.base import Response +from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage +from agentdhal_agentchat.utils import content_to_str, remove_images +from agentdhal_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall +from agentdhal_core import Image as AGImage +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + LLMMessage, + ModelFamily, + RequestUsage, + SystemMessage, + UserMessage, +) +from PIL import Image +from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright +from pydantic import BaseModel +from typing_extensions import Self + +from ._events import WebSurferEvent +from ._prompts import ( + WEB_SURFER_QA_PROMPT, + WEB_SURFER_QA_SYSTEM_MESSAGE, + WEB_SURFER_TOOL_PROMPT_MM, + WEB_SURFER_TOOL_PROMPT_TEXT, +) +from ._set_of_mark import add_set_of_mark +from ._tool_definitions import ( + TOOL_CLICK, + TOOL_HISTORY_BACK, + TOOL_HOVER, + TOOL_READ_PAGE_AND_ANSWER, + TOOL_SCROLL_DOWN, + TOOL_SCROLL_UP, + TOOL_SLEEP, + TOOL_SUMMARIZE_PAGE, + TOOL_TYPE, + TOOL_VISIT_URL, + TOOL_WEB_SEARCH, +) +from ._types import InteractiveRegion, UserContent +from .playwright_controller import PlaywrightController + +DEFAULT_CONTEXT_SIZE = 128000 + + +class MultimodalWebSurferConfig(BaseModel): + name: str + model_client: ComponentModel + downloads_folder: str | None = None + description: str | None = None + debug_dir: str | None = None + headless: bool = True + start_page: str | None = "https://www.bing.com/" + animate_actions: bool = False + to_save_screenshots: bool = False + use_ocr: bool = False + browser_channel: str | None = None + browser_data_dir: str | None = None + to_resize_viewport: bool = True + + +class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]): + """ + MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[web-surfer]" + + It launches a chromium browser and allows the playwright to interact with the web browser and can perform a variety of actions. The browser is launched on the first call to the agent and is reused for subsequent calls. + + It must be used with a multimodal model client that supports function/tool calling, ideally GPT-4o currently. + + + When :meth:`on_messages` or :meth:`on_messages_stream` is called, the following occurs: + 1) If this is the first call, the browser is initialized and the page is loaded. This is done in :meth:`_lazy_init`. The browser is only closed when :meth:`close` is called. + 2) The method :meth:`_generate_reply` is called, which then creates the final response as below. + 3) The agent takes a screenshot of the page, extracts the interactive elements, and prepares a set-of-mark screenshot with bounding boxes around the interactive elements. + 4) The agent makes a call to the :attr:`model_client` with the SOM screenshot, history of messages, and the list of available tools. + - If the model returns a string, the agent returns the string as the final response. + - If the model returns a list of tool calls, the agent executes the tool calls with :meth:`_execute_tool` using :attr:`_playwright_controller`. + - The agent returns a final response which includes a screenshot of the page, page metadata, description of the action taken and the inner text of the webpage. + 5) If at any point the agent encounters an error, it returns the error message as the final response. + + + .. note:: + Please note that using the MultimodalWebSurfer involves interacting with a digital world designed for humans, which carries inherent risks. + Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences. + Moreover, be cautious that MultimodalWebSurfer may be susceptible to prompt injection attacks from webpages. + + .. note:: + + On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses. + + .. code-block:: python + + import sys + import asyncio + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + Args: + name (str): The name of the agent. + model_client (ChatCompletionClient): The model client used by the agent. Must be multimodal and support function calling. + downloads_folder (str, optional): The folder where downloads are saved. Defaults to None, no downloads are saved. + description (str, optional): The description of the agent. Defaults to MultimodalWebSurfer.DEFAULT_DESCRIPTION. + debug_dir (str, optional): The directory where debug information is saved. Defaults to None. + headless (bool, optional): Whether the browser should be headless. Defaults to True. + start_page (str, optional): The start page for the browser. Defaults to MultimodalWebSurfer.DEFAULT_START_PAGE. + animate_actions (bool, optional): Whether to animate actions. Defaults to False. + to_save_screenshots (bool, optional): Whether to save screenshots. Defaults to False. + use_ocr (bool, optional): Whether to use OCR. Defaults to False. + browser_channel (str, optional): The browser channel. Defaults to None. + browser_data_dir (str, optional): The browser data directory. Defaults to None. + to_resize_viewport (bool, optional): Whether to resize the viewport. Defaults to True. + playwright (Playwright, optional): The playwright instance. Defaults to None. + context (BrowserContext, optional): The browser context. Defaults to None. + + + + + Example usage: + + The following example demonstrates how to create a web surfing agent with + a model client and run it for multiple turns. + + .. code-block:: python + + + import asyncio + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer + + + async def main() -> None: + # Define an agent + web_surfer_agent = MultimodalWebSurfer( + name="MultimodalWebSurfer", + model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"), + ) + + # Define a team + agent_team = RoundRobinGroupChat([web_surfer_agent], max_turns=3) + + # Run the team and stream messages to the console + stream = agent_team.run_stream(task="Navigate to the AutoGen readme on GitHub.") + await Console(stream) + # Close the browser controlled by the agent + await web_surfer_agent.close() + + + asyncio.run(main()) + """ + + component_type = "agent" + component_config_schema = MultimodalWebSurferConfig + component_provider_override = "agentdhal_extensions.agents.web_surfer.MultimodalWebSurfer" + + DEFAULT_DESCRIPTION = """ + A helpful assistant with access to a web browser. + Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.). + It can also summarize the entire page, or answer questions based on the content of the page. + It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded. + """ + DEFAULT_START_PAGE = "https://www.bing.com/" + + # Viewport dimensions + VIEWPORT_HEIGHT = 900 + VIEWPORT_WIDTH = 1440 + + # Size of the image we send to the MLM + # Current values represent a 0.85 scaling to fit within the GPT-4v short-edge constraints (768px) + MLM_HEIGHT = 765 + MLM_WIDTH = 1224 + + SCREENSHOT_TOKENS = 1105 + + def __init__( + self, + name: str, + model_client: ChatCompletionClient, + downloads_folder: str | None = None, + description: str = DEFAULT_DESCRIPTION, + debug_dir: str | None = None, + headless: bool = True, + start_page: str | None = DEFAULT_START_PAGE, + animate_actions: bool = False, + to_save_screenshots: bool = False, + use_ocr: bool = False, + browser_channel: str | None = None, + browser_data_dir: str | None = None, + to_resize_viewport: bool = True, + playwright: Playwright | None = None, + context: BrowserContext | None = None, + ): + """ + Initialize the MultimodalWebSurfer. + """ + super().__init__(name, description) + if debug_dir is None and to_save_screenshots: + raise ValueError( + "Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist." + ) + if model_client.model_info["function_calling"] is False: + raise ValueError( + "The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling." + ) + + self._model_client = model_client + self.headless = headless + self.browser_channel = browser_channel + self.browser_data_dir = browser_data_dir + self.start_page = start_page or self.DEFAULT_START_PAGE + self.downloads_folder = downloads_folder + self.debug_dir = debug_dir + self.to_save_screenshots = to_save_screenshots + self.use_ocr = use_ocr + self.to_resize_viewport = to_resize_viewport + self.animate_actions = animate_actions + + # Call init to set these in case not set + self._playwright: Playwright | None = playwright + self._context: BrowserContext | None = context + self._page: Page | None = None + self._last_download: Download | None = None + self._prior_metadata_hash: str | None = None + self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.name}.MultimodalWebSurfer") + self._chat_history: List[LLMMessage] = [] + + # Define the download handler + def _download_handler(download: Download) -> None: + self._last_download = download + + self._download_handler = _download_handler + + # Define the Playwright controller that handles the browser interactions + self._playwright_controller = PlaywrightController( + animate_actions=self.animate_actions, + downloads_folder=self.downloads_folder, + viewport_width=self.VIEWPORT_WIDTH, + viewport_height=self.VIEWPORT_HEIGHT, + _download_handler=self._download_handler, + to_resize_viewport=self.to_resize_viewport, + ) + self.default_tools = [ + TOOL_VISIT_URL, + TOOL_WEB_SEARCH, + TOOL_HISTORY_BACK, + TOOL_CLICK, + TOOL_TYPE, + TOOL_READ_PAGE_AND_ANSWER, + TOOL_SUMMARIZE_PAGE, + TOOL_SLEEP, + TOOL_HOVER, + ] + self.did_lazy_init = False # flag to check if we have initialized the browser + + async def _lazy_init( + self, + ) -> None: + """ + On the first call, we initialize the browser and the page. + """ + + # Check the current event loop policy if on windows. + if sys.platform == "win32": + current_policy = asyncio.get_event_loop_policy() + if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance( + current_policy, asyncio.WindowsProactorEventLoopPolicy + ): + warnings.warn( + "The current event loop policy is not WindowsProactorEventLoopPolicy. " + "This may cause issues with subprocesses. " + "Try setting the event loop policy to WindowsProactorEventLoopPolicy. " + "For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. " + "See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.", + stacklevel=2, + ) + + self._last_download = None + self._prior_metadata_hash = None + + # Create the playwright self + launch_args: Dict[str, Any] = {"headless": self.headless} + if self.browser_channel is not None: + launch_args["channel"] = self.browser_channel + if self._playwright is None: + self._playwright = await async_playwright().start() + + # Create the context -- are we launching persistent? + if self._context is None: + if self.browser_data_dir is None: + browser = await self._playwright.chromium.launch(**launch_args) + self._context = await browser.new_context( + user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0" + ) + else: + self._context = await self._playwright.chromium.launch_persistent_context( + self.browser_data_dir, **launch_args + ) + + # Create the page + self._context.set_default_timeout(60000) # One minute + self._page = await self._context.new_page() + assert self._page is not None + # self._page.route(lambda x: True, self._route_handler) + self._page.on("download", self._download_handler) + if self.to_resize_viewport: + await self._page.set_viewport_size({"width": self.VIEWPORT_WIDTH, "height": self.VIEWPORT_HEIGHT}) + await self._page.add_init_script( + path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js") + ) + await self._page.goto(self.start_page) + await self._page.wait_for_load_state() + + # Prepare the debug directory -- which stores the screenshots generated throughout the process + await self._set_debug_dir(self.debug_dir) + self.did_lazy_init = True + + async def close(self) -> None: + """ + Close the browser and the page. + Should be called when the agent is no longer needed. + """ + if self._page is not None: + await self._page.close() + self._page = None + if self._context is not None: + await self._context.close() + self._context = None + if self._playwright is not None: + await self._playwright.stop() + self._playwright = None + + async def _set_debug_dir(self, debug_dir: str | None) -> None: + assert self._page is not None + if self.debug_dir is None: + return + + if not os.path.isdir(self.debug_dir): + os.mkdir(self.debug_dir) + + if self.to_save_screenshots: + current_timestamp = "_" + int(time.time()).__str__() + screenshot_png_name = "screenshot" + current_timestamp + ".png" + + await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="Screenshot: " + screenshot_png_name, + ) + ) + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return (MultiModalMessage,) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + if not self.did_lazy_init: + return + assert self._page is not None + + self._chat_history.clear() + reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page( + self._page, self.start_page + ) + if reset_last_download and self._last_download is not None: + self._last_download = None + if reset_prior_metadata and self._prior_metadata_hash is not None: + self._prior_metadata_hash = None + if self.to_save_screenshots: + current_timestamp = "_" + int(time.time()).__str__() + screenshot_png_name = "screenshot" + current_timestamp + ".png" + + await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="Screenshot: " + screenshot_png_name, + ) + ) + + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="Resetting browser.", + ) + ) + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: + for chat_message in messages: + self._chat_history.append(chat_message.to_model_message()) + + self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = [] + self.model_usage: List[RequestUsage] = [] + try: + content = await self._generate_reply(cancellation_token=cancellation_token) + self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name)) + final_usage = RequestUsage( + prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]), + completion_tokens=sum([u.completion_tokens for u in self.model_usage]), + ) + if isinstance(content, str): + yield Response( + chat_message=TextMessage(content=content, source=self.name, models_usage=final_usage), + inner_messages=self.inner_messages, + ) + else: + yield Response( + chat_message=MultiModalMessage(content=content, source=self.name, models_usage=final_usage), + inner_messages=self.inner_messages, + ) + + except BaseException: + content = f"Web surfing error:\n\n{traceback.format_exc()}" + self._chat_history.append(AssistantMessage(content=content, source=self.name)) + yield Response(chat_message=TextMessage(content=content, source=self.name)) + + async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent: + """Generates the actual reply. First calls the LLM to figure out which tool to use, then executes the tool.""" + + # Lazy init, initialize the browser and the page on the first generate reply only + if not self.did_lazy_init: + await self._lazy_init() + + assert self._page is not None + + # Clone the messages, removing old screenshots + history: List[LLMMessage] = remove_images(self._chat_history) + + # Split the history, removing the last message + if len(history): + user_request = history.pop() + else: + user_request = UserMessage(content="Empty request.", source="user") + + # Truncate the history for smaller models + if self._model_client.model_info["family"] not in [ + ModelFamily.GPT_4O, + ModelFamily.O1, + ModelFamily.O3, + ModelFamily.GPT_4, + ModelFamily.GPT_35, + ]: + history = [] + + # Ask the page for interactive elements, then prepare the state-of-mark screenshot + rects = await self._playwright_controller.get_interactive_rects(self._page) + viewport = await self._playwright_controller.get_visual_viewport(self._page) + screenshot = await self._page.screenshot() + som_screenshot, visible_rects, rects_above, rects_below = add_set_of_mark(screenshot, rects) + + if self.to_save_screenshots: + current_timestamp = "_" + int(time.time()).__str__() + screenshot_png_name = "screenshot_som" + current_timestamp + ".png" + som_screenshot.save(os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="Screenshot: " + screenshot_png_name, + ) + ) + # What tools are available? + tools = self.default_tools.copy() + + # We can scroll up + if viewport["pageTop"] > 5: + tools.append(TOOL_SCROLL_UP) + + # Can scroll down + if (viewport["pageTop"] + viewport["height"] + 5) < viewport["scrollHeight"]: + tools.append(TOOL_SCROLL_DOWN) + + # Focus hint + focused = await self._playwright_controller.get_focused_rect_id(self._page) + focused_hint = "" + if focused: + name = self._target_name(focused, rects) + if name: + name = f"(and name '{name}') " + else: + name = "" + + role = "control" + try: + role = rects[focused]["role"] + except KeyError: + pass + + focused_hint = f"\nThe {role} with ID {focused} {name}currently has the input focus.\n\n" + + # Everything visible + visible_targets = "\n".join(self._format_target_list(visible_rects, rects)) + "\n\n" + + # Everything else + other_targets: List[str] = [] + other_targets.extend(self._format_target_list(rects_above, rects)) + other_targets.extend(self._format_target_list(rects_below, rects)) + + if len(other_targets) > 0: + if len(other_targets) > 30: + other_targets = other_targets[0:30] + other_targets.append("...") + other_targets_str = ( + "Additional valid interaction targets include (but are not limited to):\n" + + "\n".join(other_targets) + + "\n\n" + ) + else: + other_targets_str = "" + + state_description = "Your " + await self._get_state_description() + tool_names = "\n".join([t["name"] for t in tools]) + page_title = await self._page.title() + + prompt_message = None + if self._model_client.model_info["vision"]: + text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format( + state_description=state_description, + visible_targets=visible_targets, + other_targets_str=other_targets_str, + focused_hint=focused_hint, + tool_names=tool_names, + title=page_title, + url=self._page.url, + ).strip() + + # Scale the screenshot for the MLM, and close the original + scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) + som_screenshot.close() + if self.to_save_screenshots: + scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore + + # Create the message + prompt_message = UserMessage( + content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)], + source=self.name, + ) + else: + text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format( + state_description=state_description, + visible_targets=visible_targets, + other_targets_str=other_targets_str, + focused_hint=focused_hint, + tool_names=tool_names, + title=page_title, + url=self._page.url, + ).strip() + + # Create the message + prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name) + + history.append(prompt_message) + history.append(user_request) + + # {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]} + # print(f""" + # ================={len(history)}================= + # {history[-2].content} + # ===== + # {history[-1].content} + # =================================================== + # """) + + # Make the request + response = await self._model_client.create( + history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token + ) # , "parallel_tool_calls": False}) + + self.model_usage.append(response.usage) + message = response.content + self._last_download = None + if isinstance(message, str): + # Answer directly + self.inner_messages.append(TextMessage(content=message, source=self.name)) + return message + elif isinstance(message, list): + # Take an action + return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token) + else: + # Not sure what happened here + raise AssertionError(f"Unknown response format '{message}'") + + async def _execute_tool( + self, + message: List[FunctionCall], + rects: Dict[str, InteractiveRegion], + tool_names: str, + cancellation_token: Optional[CancellationToken] = None, + ) -> UserContent: + # Execute the tool + name = message[0].name + args = json.loads(message[0].arguments) + action_description = "" + assert self._page is not None + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + action=name, + arguments=args, + message=f"{name}( {json.dumps(args)} )", + ) + ) + self.inner_messages.append(TextMessage(content=f"{name}( {json.dumps(args)} )", source=self.name)) + + if name == "visit_url": + url = args.get("url") + action_description = f"I typed '{url}' into the browser address bar." + # Check if the argument starts with a known protocol + if url.startswith(("https://", "http://", "file://", "about:")): + reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page( + self._page, url + ) + # If the argument contains a space, treat it as a search query + elif " " in url: + reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page( + self._page, f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH" + ) + # Otherwise, prefix with https:// + else: + reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page( + self._page, "https://" + url + ) + if reset_last_download and self._last_download is not None: + self._last_download = None + if reset_prior_metadata and self._prior_metadata_hash is not None: + self._prior_metadata_hash = None + elif name == "history_back": + action_description = "I clicked the browser back button." + await self._playwright_controller.back(self._page) + + elif name == "web_search": + query = args.get("query") + action_description = f"I typed '{query}' into the browser search bar." + reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page( + self._page, f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH" + ) + if reset_last_download and self._last_download is not None: + self._last_download = None + if reset_prior_metadata and self._prior_metadata_hash is not None: + self._prior_metadata_hash = None + elif name == "scroll_up": + action_description = "I scrolled up one page in the browser." + await self._playwright_controller.page_up(self._page) + elif name == "scroll_down": + action_description = "I scrolled down one page in the browser." + await self._playwright_controller.page_down(self._page) + + elif name == "click": + target_id = str(args.get("target_id")) + target_name = self._target_name(target_id, rects) + if target_name: + action_description = f"I clicked '{target_name}'." + else: + action_description = "I clicked the control." + new_page_tentative = await self._playwright_controller.click_id(self._page, target_id) + if new_page_tentative is not None: + self._page = new_page_tentative + self._prior_metadata_hash = None + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="New tab or window.", + ) + ) + elif name == "input_text": + input_field_id = str(args.get("input_field_id")) + text_value = str(args.get("text_value")) + input_field_name = self._target_name(input_field_id, rects) + if input_field_name: + action_description = f"I typed '{text_value}' into '{input_field_name}'." + else: + action_description = f"I input '{text_value}'." + await self._playwright_controller.fill_id(self._page, input_field_id, text_value) + + elif name == "scroll_element_up": + target_id = str(args.get("target_id")) + target_name = self._target_name(target_id, rects) + + if target_name: + action_description = f"I scrolled '{target_name}' up." + else: + action_description = "I scrolled the control up." + + await self._playwright_controller.scroll_id(self._page, target_id, "up") + + elif name == "scroll_element_down": + target_id = str(args.get("target_id")) + target_name = self._target_name(target_id, rects) + + if target_name: + action_description = f"I scrolled '{target_name}' down." + else: + action_description = "I scrolled the control down." + + await self._playwright_controller.scroll_id(self._page, target_id, "down") + + elif name == "answer_question": + question = str(args.get("question")) + action_description = f"I answered the following question '{question}' based on the web page." + # Do Q&A on the DOM. No need to take further action. Browser state does not change. + return await self._summarize_page(question=question, cancellation_token=cancellation_token) + elif name == "summarize_page": + # Summarize the DOM. No need to take further action. Browser state does not change. + action_description = "I summarized the current web page" + return await self._summarize_page(cancellation_token=cancellation_token) + + elif name == "hover": + target_id = str(args.get("target_id")) + target_name = self._target_name(target_id, rects) + if target_name: + action_description = f"I hovered over '{target_name}'." + else: + action_description = "I hovered over the control." + await self._playwright_controller.hover_id(self._page, target_id) + + elif name == "sleep": + action_description = "I am waiting a short period of time before taking further action." + await self._playwright_controller.sleep(self._page, 3) + + else: + raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}") + + await self._page.wait_for_load_state() + await self._playwright_controller.sleep(self._page, 3) + + # Handle downloads + if self._last_download is not None and self.downloads_folder is not None: + fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename) + await self._last_download.save_as(fname) # type: ignore + page_body = f"Download Successful

Successfully downloaded '{self._last_download.suggested_filename}' to local path:

{fname}

" + await self._page.goto( + "data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8") + ) + await self._page.wait_for_load_state() + + # Handle metadata + page_metadata = json.dumps(await self._playwright_controller.get_page_metadata(self._page), indent=4) + metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest() + if metadata_hash != self._prior_metadata_hash: + page_metadata = ( + "\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n" + ) + else: + page_metadata = "" + self._prior_metadata_hash = metadata_hash + + new_screenshot = await self._page.screenshot() + if self.to_save_screenshots: + current_timestamp = "_" + int(time.time()).__str__() + screenshot_png_name = "screenshot" + current_timestamp + ".png" + + async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file: # type: ignore + await file.write(new_screenshot) # type: ignore + self.logger.info( + WebSurferEvent( + source=self.name, + url=self._page.url, + message="Screenshot: " + screenshot_png_name, + ) + ) + + # Return the complete observation + state_description = "The " + await self._get_state_description() + message_content = ( + f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page." + ) + + return [ + re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines + AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))), + ] + + async def _get_state_description(self) -> str: + assert self._playwright_controller is not None + assert self._page is not None + + # Describe the viewport of the new page in words + viewport = await self._playwright_controller.get_visual_viewport(self._page) + percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"]) + percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"]) + if percent_scrolled < 1: # Allow some rounding error + position_text = "at the top of the page" + elif percent_scrolled + percent_visible >= 99: # Allow some rounding error + position_text = "at the bottom of the page" + else: + position_text = str(percent_scrolled) + "% down from the top of the page" + + visible_text = await self._playwright_controller.get_visible_text(self._page) + + # Return the complete observation + page_title = await self._page.title() + message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n" + message_content += f"The following text is visible in the viewport:\n\n{visible_text}" + return message_content + + def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None: + try: + return rects[target]["aria_name"].strip() + except KeyError: + return None + + def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion]) -> List[str]: + """ + Format the list of targets in the webpage as a string to be used in the agent's prompt. + """ + targets: List[str] = [] + for r in list(set(ids)): + if r in rects: + # Get the role + aria_role = rects[r].get("role", "").strip() + if len(aria_role) == 0: + aria_role = rects[r].get("tag_name", "").strip() + + # Get the name + aria_name = re.sub(r"[\n\r]+", " ", rects[r].get("aria_name", "")).strip() + + # What are the actions? + actions = ['"click", "hover"'] + if rects[r]["role"] in ["textbox", "searchbox", "search"]: + actions = ['"input_text"'] + actions_str = "[" + ",".join(actions) + "]" + + targets.append(f'{{"id": {r}, "name": "{aria_name}", "role": "{aria_role}", "tools": {actions_str} }}') + + return targets + + async def _summarize_page( + self, + question: str | None = None, + cancellation_token: Optional[CancellationToken] = None, + ) -> str: + assert self._page is not None + + page_markdown: str = await self._playwright_controller.get_page_markdown(self._page) + + title: str = self._page.url + try: + title = await self._page.title() + except Exception: + pass + + # Take a screenshot and scale it + screenshot = Image.open(io.BytesIO(await self._page.screenshot())) + scaled_screenshot = screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) + screenshot.close() + ag_image = AGImage.from_pil(scaled_screenshot) + + # Prepare the system prompt + messages: List[LLMMessage] = [] + messages.append(SystemMessage(content=WEB_SURFER_QA_SYSTEM_MESSAGE)) + prompt = WEB_SURFER_QA_PROMPT(title, question) + # Grow the buffer (which is added to the prompt) until we overflow the context window or run out of lines + buffer = "" + # for line in re.split(r"([\r\n]+)", page_markdown): + for line in page_markdown.splitlines(): + trial_message = UserMessage( + content=prompt + buffer + line, + source=self.name, + ) + + try: + remaining = self._model_client.remaining_tokens(messages + [trial_message]) + except KeyError: + # Use the default if the model isn't found + remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message]) + + if self._model_client.model_info["vision"] and remaining <= 0: + break + + if self._model_client.model_info["vision"] and remaining <= self.SCREENSHOT_TOKENS: + break + + buffer += line + + # Nothing to do + buffer = buffer.strip() + if len(buffer) == 0: + return "Nothing to summarize." + + # Append the message + if self._model_client.model_info["vision"]: + # Multimodal + messages.append( + UserMessage( + content=[ + prompt + buffer, + ag_image, + ], + source=self.name, + ) + ) + else: + # Text only + messages.append( + UserMessage( + content=prompt + buffer, + source=self.name, + ) + ) + + # Generate the response + response = await self._model_client.create(messages, cancellation_token=cancellation_token) + self.model_usage.append(response.usage) + scaled_screenshot.close() + assert isinstance(response.content, str) + return response.content + + def _to_config(self) -> MultimodalWebSurferConfig: + return MultimodalWebSurferConfig( + name=self.name, + model_client=self._model_client.dump_component(), + downloads_folder=self.downloads_folder, + description=self.description, + debug_dir=self.debug_dir, + headless=self.headless, + start_page=self.start_page, + animate_actions=self.animate_actions, + to_save_screenshots=self.to_save_screenshots, + use_ocr=self.use_ocr, + browser_channel=self.browser_channel, + browser_data_dir=self.browser_data_dir, + to_resize_viewport=self.to_resize_viewport, + ) + + @classmethod + def _from_config(cls, config: MultimodalWebSurferConfig) -> Self: + return cls( + name=config.name, + model_client=ChatCompletionClient.load_component(config.model_client), + downloads_folder=config.downloads_folder, + description=config.description or cls.DEFAULT_DESCRIPTION, + debug_dir=config.debug_dir, + headless=config.headless, + start_page=config.start_page or cls.DEFAULT_START_PAGE, + animate_actions=config.animate_actions, + to_save_screenshots=config.to_save_screenshots, + use_ocr=config.use_ocr, + browser_channel=config.browser_channel, + browser_data_dir=config.browser_data_dir, + to_resize_viewport=config.to_resize_viewport, + ) diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_prompts.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_prompts.py new file mode 100644 index 0000000..d1f1885 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_prompts.py @@ -0,0 +1,52 @@ +WEB_SURFER_TOOL_PROMPT_MM = """ +{state_description} + +Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below: + +{visible_targets}{other_targets_str}{focused_hint} + +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: + +{tool_names} + +When deciding between tools, consider if the request can be best addressed by: + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) + +My request follows: +""" + +WEB_SURFER_TOOL_PROMPT_TEXT = """ +{state_description} + +You have also identified the following interactive components: + +{visible_targets}{other_targets_str}{focused_hint} + +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: + +{tool_names} + +When deciding between tools, consider if the request can be best addressed by: + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) + +My request follows: +""" + + +WEB_SURFER_QA_SYSTEM_MESSAGE = """ +You are a helpful assistant that can summarize long documents to answer question. +""" + + +def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str: + base_prompt = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport." + if question is not None: + return ( + f"{base_prompt} Please summarize the webpage into one or two paragraphs with respect to '{question}':\n\n" + ) + else: + return f"{base_prompt} Please summarize the webpage into one or two paragraphs:\n\n" diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_set_of_mark.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_set_of_mark.py new file mode 100644 index 0000000..07656ce --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_set_of_mark.py @@ -0,0 +1,96 @@ +import io +import random +from typing import BinaryIO, Dict, List, Tuple, cast + +from PIL import Image, ImageDraw, ImageFont + +from ._types import DOMRectangle, InteractiveRegion + +TOP_NO_LABEL_ZONE = 20 # Don't print any labels close the top of the page + + +def add_set_of_mark( + screenshot: bytes | Image.Image | io.BufferedIOBase, ROIs: Dict[str, InteractiveRegion] +) -> Tuple[Image.Image, List[str], List[str], List[str]]: + if isinstance(screenshot, Image.Image): + return _add_set_of_mark(screenshot, ROIs) + + if isinstance(screenshot, bytes): + screenshot = io.BytesIO(screenshot) + + # TODO: Not sure why this cast was needed, but by this point screenshot is a binary file-like object + image = Image.open(cast(BinaryIO, screenshot)) + comp, visible_rects, rects_above, rects_below = _add_set_of_mark(image, ROIs) + image.close() + return comp, visible_rects, rects_above, rects_below + + +def _add_set_of_mark( + screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion] +) -> Tuple[Image.Image, List[str], List[str], List[str]]: + visible_rects: List[str] = list() + rects_above: List[str] = list() # Scroll up to see + rects_below: List[str] = list() # Scroll down to see + + fnt = ImageFont.load_default(14) + base = screenshot.convert("L").convert("RGBA") + overlay = Image.new("RGBA", base.size) + + draw = ImageDraw.Draw(overlay) + for r in ROIs: + for rect in ROIs[r]["rects"]: + # Empty rectangles + if not rect: + continue + if rect["width"] * rect["height"] == 0: + continue + + mid = ((rect["right"] + rect["left"]) / 2.0, (rect["top"] + rect["bottom"]) / 2.0) + + if 0 <= mid[0] and mid[0] < base.size[0]: + if mid[1] < 0: + rects_above.append(r) + elif mid[1] >= base.size[1]: + rects_below.append(r) + else: + visible_rects.append(r) + _draw_roi(draw, int(r), fnt, rect) + + comp = Image.alpha_composite(base, overlay) + overlay.close() + return comp, visible_rects, rects_above, rects_below + + +def _draw_roi( + draw: ImageDraw.ImageDraw, idx: int, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, rect: DOMRectangle +) -> None: + color = _color(idx) + luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11 + text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255) + + roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"])) + + label_location = (rect["right"], rect["top"]) + label_anchor = "rb" + + if label_location[1] <= TOP_NO_LABEL_ZONE: + label_location = (rect["right"], rect["bottom"]) + label_anchor = "rt" + + draw.rectangle(roi, outline=color, fill=(color[0], color[1], color[2], 48), width=2) + + # TODO: Having trouble with these types being partially Unknown. + bbox = draw.textbbox(label_location, str(idx), font=font, anchor=label_anchor, align="center") # type: ignore + bbox = (bbox[0] - 3, bbox[1] - 3, bbox[2] + 3, bbox[3] + 3) + draw.rectangle(bbox, fill=color) + + # TODO: Having trouble with these types being partially Unknown. + draw.text(label_location, str(idx), fill=text_color, font=font, anchor=label_anchor, align="center") # type: ignore + + +def _color(identifier: int) -> Tuple[int, int, int, int]: + rnd = random.Random(int(identifier)) + color = [rnd.randint(0, 255), rnd.randint(125, 255), rnd.randint(0, 50)] + rnd.shuffle(color) + color.append(255) + return cast(Tuple[int, int, int, int], tuple(color)) diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_tool_definitions.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_tool_definitions.py new file mode 100644 index 0000000..04d530e --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_tool_definitions.py @@ -0,0 +1,317 @@ +from typing import Any, Dict + +from agentdhal_core.tools._base import ParametersSchema, ToolSchema + + +def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema: + return ToolSchema( + name=tooldef["function"]["name"], + description=tooldef["function"]["description"], + parameters=ParametersSchema( + type="object", + properties=tooldef["function"]["parameters"]["properties"], + required=tooldef["function"]["parameters"]["required"], + ), + ) + + +REASONING_TOOL_PROMPT = ( + "A short description of the action to be performed and reason for doing so, do not mention the user." +) + +TOOL_VISIT_URL: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "visit_url", + "description": "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "url": { + "type": "string", + "description": "The URL to visit in the browser.", + }, + }, + "required": ["reasoning", "url"], + }, + }, + } +) + +TOOL_WEB_SEARCH: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "web_search", + "description": "Performs a web search on Bing.com with the given query.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "query": { + "type": "string", + "description": "The web search query to use.", + }, + }, + "required": ["reasoning", "query"], + }, + }, + } +) + +TOOL_HISTORY_BACK: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "history_back", + "description": "Navigates back one page in the browser's history. This is equivalent to clicking the browser back button.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + }, + "required": ["reasoning"], + }, + }, + } +) + +TOOL_SCROLL_UP: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "scroll_up", + "description": "Scrolls the entire browser viewport one page UP towards the beginning.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + }, + "required": ["reasoning"], + }, + }, + } +) + +TOOL_SCROLL_DOWN: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "scroll_down", + "description": "Scrolls the entire browser viewport one page DOWN towards the end.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + }, + "required": ["reasoning"], + }, + }, + } +) + +TOOL_CLICK: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "click", + "description": "Clicks the mouse on the target with the given id.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "target_id": { + "type": "integer", + "description": "The numeric id of the target to click.", + }, + }, + "required": ["reasoning", "target_id"], + }, + }, + } +) + +TOOL_TYPE: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "input_text", + "description": "Types the given text value into the specified field.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "input_field_id": { + "type": "integer", + "description": "The numeric id of the input field to receive the text.", + }, + "text_value": { + "type": "string", + "description": "The text to type into the input field.", + }, + }, + "required": ["reasoning", "input_field_id", "text_value"], + }, + }, + } +) + +TOOL_SCROLL_ELEMENT_DOWN: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "scroll_element_down", + "description": "Scrolls a given html element (e.g., a div or a menu) DOWN.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "target_id": { + "type": "integer", + "description": "The numeric id of the target to scroll down.", + }, + }, + "required": ["reasoning", "target_id"], + }, + }, + } +) + +TOOL_SCROLL_ELEMENT_UP: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "scroll_element_up", + "description": "Scrolls a given html element (e.g., a div or a menu) UP.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "target_id": { + "type": "integer", + "description": "The numeric id of the target to scroll UP.", + }, + }, + "required": ["reasoning", "target_id"], + }, + }, + } +) + +TOOL_HOVER: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "hover", + "description": "Hovers the mouse over the target with the given id.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "target_id": { + "type": "integer", + "description": "The numeric id of the target to hover over.", + }, + }, + "required": ["reasoning", "target_id"], + }, + }, + } +) + + +TOOL_READ_PAGE_AND_ANSWER: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "answer_question", + "description": "Uses AI to answer a question about the current webpage's content.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + "question": { + "type": "string", + "description": "The question to answer.", + }, + }, + "required": ["reasoning", "question"], + }, + }, + } +) + +TOOL_SUMMARIZE_PAGE: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "summarize_page", + "description": "Uses AI to summarize the entire page.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + }, + "required": ["reasoning"], + }, + }, + } +) + +TOOL_SLEEP: ToolSchema = _load_tool( + { + "type": "function", + "function": { + "name": "sleep", + "description": "Wait a short period of time. Call this function if the page has not yet fully loaded, or if it is determined that a small delay would increase the task's chances of success.", + "parameters": { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": REASONING_TOOL_PROMPT, + }, + }, + "required": ["reasoning"], + }, + }, + } +) diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py new file mode 100644 index 0000000..a3ebae1 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py @@ -0,0 +1,106 @@ +from typing import Any, Dict, List, TypedDict, Union + +from agentdhal_core import FunctionCall, Image +from agentdhal_core.models import FunctionExecutionResult + +UserContent = Union[str, List[Union[str, Image]]] +AssistantContent = Union[str, List[FunctionCall]] +FunctionExecutionContent = List[FunctionExecutionResult] +SystemContent = str + + +class DOMRectangle(TypedDict): + x: Union[int, float] + y: Union[int, float] + width: Union[int, float] + height: Union[int, float] + top: Union[int, float] + right: Union[int, float] + bottom: Union[int, float] + left: Union[int, float] + + +class VisualViewport(TypedDict): + height: Union[int, float] + width: Union[int, float] + offsetLeft: Union[int, float] + offsetTop: Union[int, float] + pageLeft: Union[int, float] + pageTop: Union[int, float] + scale: Union[int, float] + clientWidth: Union[int, float] + clientHeight: Union[int, float] + scrollWidth: Union[int, float] + scrollHeight: Union[int, float] + + +class InteractiveRegion(TypedDict): + tag_name: str + role: str + aria_name: str + v_scrollable: bool + rects: List[DOMRectangle] + + +# Helper functions for dealing with JSON. Not sure there's a better way? + + +def _get_str(d: Any, k: str) -> str: + val = d[k] + assert isinstance(val, str) + return val + + +def _get_number(d: Any, k: str) -> Union[int, float]: + val = d[k] + assert isinstance(val, int) or isinstance(val, float) + return val + + +def _get_bool(d: Any, k: str) -> bool: + val = d[k] + assert isinstance(val, bool) + return val + + +def domrectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle: + return DOMRectangle( + x=_get_number(rect, "x"), + y=_get_number(rect, "y"), + width=_get_number(rect, "width"), + height=_get_number(rect, "height"), + top=_get_number(rect, "top"), + right=_get_number(rect, "right"), + bottom=_get_number(rect, "bottom"), + left=_get_number(rect, "left"), + ) + + +def interactiveregion_from_dict(region: Dict[str, Any]) -> InteractiveRegion: + typed_rects: List[DOMRectangle] = [] + for rect in region["rects"]: + typed_rects.append(domrectangle_from_dict(rect)) + + return InteractiveRegion( + tag_name=_get_str(region, "tag_name"), + role=_get_str(region, "role"), + aria_name=_get_str(region, "aria-name"), + v_scrollable=_get_bool(region, "v-scrollable"), + rects=typed_rects, + ) + + +def visualviewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport: + return VisualViewport( + height=_get_number(viewport, "height"), + width=_get_number(viewport, "width"), + offsetLeft=_get_number(viewport, "offsetLeft"), + offsetTop=_get_number(viewport, "offsetTop"), + pageLeft=_get_number(viewport, "pageLeft"), + pageTop=_get_number(viewport, "pageTop"), + scale=_get_number(viewport, "scale"), + clientWidth=_get_number(viewport, "clientWidth"), + clientHeight=_get_number(viewport, "clientHeight"), + scrollWidth=_get_number(viewport, "scrollWidth"), + scrollHeight=_get_number(viewport, "scrollHeight"), + ) diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js b/agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js new file mode 100644 index 0000000..1363e83 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js @@ -0,0 +1,429 @@ +var MultimodalWebSurfer = MultimodalWebSurfer || (function() { + let nextLabel = 10; + + let roleMapping = { + "a": "link", + "area": "link", + "button": "button", + "input, type=button": "button", + "input, type=checkbox": "checkbox", + "input, type=email": "textbox", + "input, type=number": "spinbutton", + "input, type=radio": "radio", + "input, type=range": "slider", + "input, type=reset": "button", + "input, type=search": "searchbox", + "input, type=submit": "button", + "input, type=tel": "textbox", + "input, type=text": "textbox", + "input, type=url": "textbox", + "search": "search", + "select": "combobox", + "option": "option", + "textarea": "textbox" + }; + + let getCursor = function(elm) { + return window.getComputedStyle(elm)["cursor"]; + }; + + let getInteractiveElements = function() { + + let results = [] + let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"]; + let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"]; + + // Get the main interactive elements + let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])"); + for (let i=0; i -1) { + results.push(nodeList[i]); + } + } + } + + // Any element that changes the cursor to something implying interactivity + nodeList = document.querySelectorAll("*"); + for (let i=0; i= 0) { + continue; + } + + // Move up to the first instance of this cursor change + parent = node.parentNode; + while (parent && getCursor(parent) == cursor) { + node = parent; + parent = node.parentNode; + } + + // Add the node if it is new + if (results.indexOf(node) == -1) { + results.push(node); + } + } + + return results; + }; + + let labelElements = function(elements) { + for (let i=0; i= 1; + + let record = { + "tag_name": ariaRole[1], + "role": ariaRole[0], + "aria-name": ariaName, + "v-scrollable": vScrollable, + "rects": [] + }; + + for (const rect of rects) { + let x = rect.left + rect.width/2; + let y = rect.top + rect.height/2; + if (isTopmost(elements[i], x, y)) { + record["rects"].push(JSON.parse(JSON.stringify(rect))); + } + } + + if (record["rects"].length > 0) { + results[key] = record; + } + } + return results; + }; + + let getVisualViewport = function() { + let vv = window.visualViewport; + let de = document.documentElement; + return { + "height": vv ? vv.height : 0, + "width": vv ? vv.width : 0, + "offsetLeft": vv ? vv.offsetLeft : 0, + "offsetTop": vv ? vv.offsetTop : 0, + "pageLeft": vv ? vv.pageLeft : 0, + "pageTop": vv ? vv.pageTop : 0, + "scale": vv ? vv.scale : 0, + "clientWidth": de ? de.clientWidth : 0, + "clientHeight": de ? de.clientHeight : 0, + "scrollWidth": de ? de.scrollWidth : 0, + "scrollHeight": de ? de.scrollHeight : 0 + }; + }; + + let _getMetaTags = function() { + let meta = document.querySelectorAll("meta"); + let results = {}; + for (let i = 0; i { + addValue(information, propName, childInfo); + }); + } + + } else if (child.hasAttribute('itemprop')) { + const itemProp = child.getAttribute('itemprop'); + itemProp.split(' ').forEach(propName => { + if (propName === 'url') { + addValue(information, propName, child.href); + } else { + addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || "")); + } + }); + traverseItem(child, information); + } else { + traverseItem(child, information); + } + } + } + + const microdata = []; + + document.querySelectorAll("[itemscope]").forEach(function(elem, i) { + const itemType = elem.getAttribute('itemtype'); + const information = { + itemType: itemType + }; + traverseItem(elem, information); + microdata.push(information); + }); + + return microdata; + }; + + let getPageMetadata = function() { + let jsonld = _getJsonLd(); + let metaTags = _getMetaTags(); + let microdata = _getMicrodata(); + let results = {} + if (jsonld.length > 0) { + try { + results["jsonld"] = JSON.parse(jsonld); + } + catch (e) { + results["jsonld"] = jsonld; + } + } + if (microdata.length > 0) { + results["microdata"] = microdata; + } + for (let key in metaTags) { + if (metaTags.hasOwnProperty(key)) { + results["meta_tags"] = metaTags; + break; + } + } + return results; + }; + + + let getVisibleText = function() { + // Get the window’s current viewport boundaries + const viewportHeight = window.innerHeight || document.documentElement.clientHeight; + const viewportWidth = window.innerWidth || document.documentElement.clientWidth; + + let textInView = ""; + const walker = document.createTreeWalker( + document.body, + NodeFilter.SHOW_TEXT, + null, + false + ); + + while (walker.nextNode()) { + const textNode = walker.currentNode; + // Create a range to retrieve bounding rectangles of the current text node + const range = document.createRange(); + range.selectNodeContents(textNode); + + const rects = range.getClientRects(); + + // Check if any rect is inside (or partially inside) the viewport + for (const rect of rects) { + const isVisible = + rect.width > 0 && + rect.height > 0 && + rect.bottom >= 0 && + rect.right >= 0 && + rect.top <= viewportHeight && + rect.left <= viewportWidth; + + if (isVisible) { + textInView += textNode.nodeValue.replace(/\s+/g, " "); + // Is the parent a block element? + if (textNode.parentNode) { + const parent = textNode.parentNode; + const style = window.getComputedStyle(parent); + if (["inline", "hidden", "none"].indexOf(style.display) === -1) { + textInView += "\n"; + } + } + break; // No need to check other rects once found visible + } + } + } + + // Remove blank lines from textInView + textInView = textInView.replace(/^\s*\n/gm, "").trim().replace(/\n+/g, "\n"); + return textInView; + }; + + return { + getInteractiveRects: getInteractiveRects, + getVisualViewport: getVisualViewport, + getFocusedElementId: getFocusedElementId, + getPageMetadata: getPageMetadata, + getVisibleText: getVisibleText, + }; +})(); diff --git a/agent_dhal/agentdhal_extensions/agents/web_surfer/playwright_controller.py b/agent_dhal/agentdhal_extensions/agents/web_surfer/playwright_controller.py new file mode 100644 index 0000000..90a830b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/agents/web_surfer/playwright_controller.py @@ -0,0 +1,578 @@ +import asyncio +import base64 +import io +import os +import random +import warnings +from types import ModuleType +from typing import Any, Callable, Dict, Optional, Tuple, Union, cast + +from playwright._impl._errors import Error as PlaywrightError +from playwright._impl._errors import TimeoutError +from playwright.async_api import Download, Page + +from ._types import ( + InteractiveRegion, + VisualViewport, + interactiveregion_from_dict, + visualviewport_from_dict, +) + +markitdown: ModuleType | None = None +try: + # Suppress warnings from markitdown -- which is pretty chatty + warnings.filterwarnings(action="ignore", module="markitdown") + import markitdown +except ImportError: + pass + + +class PlaywrightController: + """ + A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling. + + Args: + downloads_folder (str | None): The folder to save downloads to. If None, downloads are not saved. + animate_actions (bool): Whether to animate the actions (create fake cursor to click). + viewport_width (int): The width of the viewport. + viewport_height (int): The height of the viewport. + _download_handler (Optional[Callable[[Download], None]]): A function to handle downloads. + to_resize_viewport (bool): Whether to resize the viewport + """ + + def __init__( + self, + downloads_folder: str | None = None, + animate_actions: bool = False, + viewport_width: int = 1440, + viewport_height: int = 900, + _download_handler: Optional[Callable[[Download], None]] = None, + to_resize_viewport: bool = True, + ) -> None: + """ + Initialize the PlaywrightController. + """ + assert isinstance(animate_actions, bool) + assert isinstance(viewport_width, int) + assert isinstance(viewport_height, int) + assert viewport_height > 0 + assert viewport_width > 0 + + self.animate_actions = animate_actions + self.downloads_folder = downloads_folder + self.viewport_width = viewport_width + self.viewport_height = viewport_height + self._download_handler = _download_handler + self.to_resize_viewport = to_resize_viewport + self._page_script: str = "" + self.last_cursor_position: Tuple[float, float] = (0.0, 0.0) + self._markdown_converter: Optional[Any] | None = None + + # Read page_script + with open( + os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt", encoding="utf-8" + ) as fh: + self._page_script = fh.read() + + async def sleep(self, page: Page, duration: Union[int, float]) -> None: + """ + Pause the execution for a specified duration. + + Args: + page (Page): The Playwright page object. + duration (Union[int, float]): The duration to sleep in milliseconds. + """ + assert page is not None + await page.wait_for_timeout(duration * 1000) + + async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion]: + """ + Retrieve interactive regions from the web page. + + Args: + page (Page): The Playwright page object. + + Returns: + Dict[str, InteractiveRegion]: A dictionary of interactive regions. + """ + assert page is not None + # Read the regions from the DOM + try: + await page.evaluate(self._page_script) + except Exception: + pass + result = cast(Dict[str, Dict[str, Any]], await page.evaluate("MultimodalWebSurfer.getInteractiveRects();")) + + # Convert the results into appropriate types + assert isinstance(result, dict) + typed_results: Dict[str, InteractiveRegion] = {} + for k in result: + assert isinstance(k, str) + typed_results[k] = interactiveregion_from_dict(result[k]) + + return typed_results + + async def get_visual_viewport(self, page: Page) -> VisualViewport: + """ + Retrieve the visual viewport of the web page. + + Args: + page (Page): The Playwright page object. + + Returns: + VisualViewport: The visual viewport of the page. + """ + assert page is not None + try: + await page.evaluate(self._page_script) + except Exception: + pass + return visualviewport_from_dict(await page.evaluate("MultimodalWebSurfer.getVisualViewport();")) + + async def get_focused_rect_id(self, page: Page) -> str | None: + """ + Retrieve the ID of the currently focused element. + + Args: + page (Page): The Playwright page object. + + Returns: + str: The ID of the focused element or None if no control has focus. + """ + assert page is not None + try: + await page.evaluate(self._page_script) + except Exception: + pass + result = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();") + return None if result is None else str(result) + + async def get_page_metadata(self, page: Page) -> Dict[str, Any]: + """ + Retrieve metadata from the web page. + + Args: + page (Page): The Playwright page object. + + Returns: + Dict[str, Any]: A dictionary of page metadata. + """ + assert page is not None + try: + await page.evaluate(self._page_script) + except Exception: + pass + result = await page.evaluate("MultimodalWebSurfer.getPageMetadata();") + assert isinstance(result, dict) + return cast(Dict[str, Any], result) + + async def on_new_page(self, page: Page) -> None: + """ + Handle actions to perform on a new page. + + Args: + page (Page): The Playwright page object. + """ + assert page is not None + page.on("download", self._download_handler) # type: ignore + if self.to_resize_viewport and self.viewport_width and self.viewport_height: + await page.set_viewport_size({"width": self.viewport_width, "height": self.viewport_height}) + await self.sleep(page, 0.2) + await page.add_init_script(path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")) + await page.wait_for_load_state() + + async def back(self, page: Page) -> None: + """ + Navigate back to the previous page. + + Args: + page (Page): The Playwright page object. + """ + assert page is not None + await page.go_back() + + async def visit_page(self, page: Page, url: str) -> Tuple[bool, bool]: + """ + Visit a specified URL. + + Args: + page (Page): The Playwright page object. + url (str): The URL to visit. + + Returns: + Tuple[bool, bool]: A tuple indicating whether to reset prior metadata hash and last download. + """ + assert page is not None + reset_prior_metadata_hash = False + reset_last_download = False + try: + # Regular webpage + await page.goto(url) + await page.wait_for_load_state() + reset_prior_metadata_hash = True + except Exception as e_outer: + # Downloaded file + if self.downloads_folder and "net::ERR_ABORTED" in str(e_outer): + async with page.expect_download() as download_info: + try: + await page.goto(url) + except Exception as e_inner: + if "net::ERR_ABORTED" in str(e_inner): + pass + else: + raise e_inner + download = await download_info.value + fname = os.path.join(self.downloads_folder, download.suggested_filename) + await download.save_as(fname) + message = f"

Successfully downloaded '{download.suggested_filename}' to local path:

{fname}

" + await page.goto( + "data:text/html;base64," + base64.b64encode(message.encode("utf-8")).decode("utf-8") + ) + reset_last_download = True + else: + raise e_outer + return reset_prior_metadata_hash, reset_last_download + + async def page_down(self, page: Page) -> None: + """ + Scroll the page down by one viewport height minus 50 pixels. + + Args: + page (Page): The Playwright page object. + """ + assert page is not None + await page.evaluate(f"window.scrollBy(0, {self.viewport_height-50});") + + async def page_up(self, page: Page) -> None: + """ + Scroll the page up by one viewport height minus 50 pixels. + + Args: + page (Page): The Playwright page object. + """ + assert page is not None + await page.evaluate(f"window.scrollBy(0, -{self.viewport_height-50});") + + async def gradual_cursor_animation( + self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float + ) -> None: + """ + Animate the cursor movement gradually from start to end coordinates. + + Args: + page (Page): The Playwright page object. + start_x (float): The starting x-coordinate. + start_y (float): The starting y-coordinate. + end_x (float): The ending x-coordinate. + end_y (float): The ending y-coordinate. + """ + # animation helper + steps = 20 + for step in range(steps): + x = start_x + (end_x - start_x) * (step / steps) + y = start_y + (end_y - start_y) * (step / steps) + # await page.mouse.move(x, y, steps=1) + await page.evaluate(f""" + (function() {{ + let cursor = document.getElementById('red-cursor'); + cursor.style.left = '{x}px'; + cursor.style.top = '{y}px'; + }})(); + """) + await asyncio.sleep(0.05) + + self.last_cursor_position = (end_x, end_y) + + async def add_cursor_box(self, page: Page, identifier: str) -> None: + """ + Add a red cursor box around the element with the given identifier. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + """ + # animation helper + await page.evaluate(f""" + (function() {{ + let elm = document.querySelector("[__elementId='{identifier}']"); + if (elm) {{ + elm.style.transition = 'border 0.3s ease-in-out'; + elm.style.border = '2px solid red'; + }} + }})(); + """) + await asyncio.sleep(0.3) + + # Create a red cursor + await page.evaluate(""" + (function() { + let cursor = document.createElement('div'); + cursor.id = 'red-cursor'; + cursor.style.width = '10px'; + cursor.style.height = '10px'; + cursor.style.backgroundColor = 'red'; + cursor.style.position = 'absolute'; + cursor.style.borderRadius = '50%'; + cursor.style.zIndex = '10000'; + document.body.appendChild(cursor); + })(); + """) + + async def remove_cursor_box(self, page: Page, identifier: str) -> None: + """ + Remove the red cursor box around the element with the given identifier. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + """ + # Remove the highlight and cursor + await page.evaluate(f""" + (function() {{ + let elm = document.querySelector("[__elementId='{identifier}']"); + if (elm) {{ + elm.style.border = ''; + }} + let cursor = document.getElementById('red-cursor'); + if (cursor) {{ + cursor.remove(); + }} + }})(); + """) + + async def click_id(self, page: Page, identifier: str) -> Page | None: + """ + Click the element with the given identifier. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + + Returns: + Page | None: The new page if a new page is opened, otherwise None. + """ + new_page: Page | None = None + assert page is not None + target = page.locator(f"[__elementId='{identifier}']") + + # See if it exists + try: + await target.wait_for(timeout=5000) + except TimeoutError: + raise ValueError("No such element.") from None + + # Click it + await target.scroll_into_view_if_needed() + await asyncio.sleep(0.3) + + box = cast(Dict[str, Union[int, float]], await target.bounding_box()) + + if self.animate_actions: + await self.add_cursor_box(page, identifier) + # Move cursor to the box slowly + start_x, start_y = self.last_cursor_position + end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2 + await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y) + await asyncio.sleep(0.1) + + try: + # Give it a chance to open a new page + async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore + await page.mouse.click(end_x, end_y, delay=10) + new_page = await page_info.value # type: ignore + assert isinstance(new_page, Page) + await self.on_new_page(new_page) + except TimeoutError: + pass + await self.remove_cursor_box(page, identifier) + + else: + try: + # Give it a chance to open a new page + async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore + await page.mouse.click(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2, delay=10) + new_page = await page_info.value # type: ignore + assert isinstance(new_page, Page) + await self.on_new_page(new_page) + except TimeoutError: + pass + return new_page # type: ignore + + async def hover_id(self, page: Page, identifier: str) -> None: + """ + Hover the mouse over the element with the given identifier. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + """ + assert page is not None + target = page.locator(f"[__elementId='{identifier}']") + + # See if it exists + try: + await target.wait_for(timeout=5000) + except TimeoutError: + raise ValueError("No such element.") from None + + # Hover over it + await target.scroll_into_view_if_needed() + await asyncio.sleep(0.3) + + box = cast(Dict[str, Union[int, float]], await target.bounding_box()) + + if self.animate_actions: + await self.add_cursor_box(page, identifier) + # Move cursor to the box slowly + start_x, start_y = self.last_cursor_position + end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2 + await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y) + await asyncio.sleep(0.1) + await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2) + + await self.remove_cursor_box(page, identifier) + else: + await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2) + + async def fill_id(self, page: Page, identifier: str, value: str, press_enter: bool = True) -> None: + """ + Fill the element with the given identifier with the specified value. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + value (str): The value to fill. + """ + assert page is not None + target = page.locator(f"[__elementId='{identifier}']") + + # See if it exists + try: + await target.wait_for(timeout=5000) + except TimeoutError: + raise ValueError("No such element.") from None + + # Fill it + await target.scroll_into_view_if_needed() + box = cast(Dict[str, Union[int, float]], await target.bounding_box()) + + if self.animate_actions: + await self.add_cursor_box(page, identifier) + # Move cursor to the box slowly + start_x, start_y = self.last_cursor_position + end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2 + await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y) + await asyncio.sleep(0.1) + + # Focus on the element + await target.focus() + if self.animate_actions: + # fill char by char to mimic human speed for short text and type fast for long text + if len(value) < 100: + delay_typing_speed = 50 + 100 * random.random() + else: + delay_typing_speed = 10 + await target.press_sequentially(value, delay=delay_typing_speed) + else: + try: + await target.fill(value) + except PlaywrightError: + await target.press_sequentially(value) + if press_enter: + await target.press("Enter") + + if self.animate_actions: + await self.remove_cursor_box(page, identifier) + + async def scroll_id(self, page: Page, identifier: str, direction: str) -> None: + """ + Scroll the element with the given identifier in the specified direction. + + Args: + page (Page): The Playwright page object. + identifier (str): The element identifier. + direction (str): The direction to scroll ("up" or "down"). + """ + assert page is not None + await page.evaluate( + f""" + (function() {{ + let elm = document.querySelector("[__elementId='{identifier}']"); + if (elm) {{ + if ("{direction}" == "up") {{ + elm.scrollTop = Math.max(0, elm.scrollTop - elm.clientHeight); + }} + else {{ + elm.scrollTop = Math.min(elm.scrollHeight - elm.clientHeight, elm.scrollTop + elm.clientHeight); + }} + }} + }})(); + """ + ) + + async def get_webpage_text(self, page: Page, n_lines: int = 50) -> str: + """ + Retrieve the text content of the web page. + + Args: + page (Page): The Playwright page object. + n_lines (int): The number of lines to return from the page inner text. + + Returns: + str: The text content of the page. + """ + assert page is not None + try: + text_in_viewport = await page.evaluate("""() => { + return document.body.innerText; + }""") + text_in_viewport = "\n".join(text_in_viewport.split("\n")[:n_lines]) + # remove empty lines + text_in_viewport = "\n".join([line for line in text_in_viewport.split("\n") if line.strip()]) + assert isinstance(text_in_viewport, str) + return text_in_viewport + except Exception: + return "" + + async def get_visible_text(self, page: Page) -> str: + """ + Retrieve the text content of the browser viewport (approximately). + + Args: + page (Page): The Playwright page object. + + Returns: + str: The text content of the page. + """ + assert page is not None + try: + await page.evaluate(self._page_script) + except Exception: + pass + result = await page.evaluate("MultimodalWebSurfer.getVisibleText();") + assert isinstance(result, str) + return result + + async def get_page_markdown(self, page: Page) -> str: + """ + Retrieve the markdown content of the web page. + Currently not implemented. + + Args: + page (Page): The Playwright page object. + + Returns: + str: The markdown content of the page. + """ + assert page is not None + if self._markdown_converter is None and markitdown is not None: + self._markdown_converter = markitdown.MarkItDown() + assert self._markdown_converter is not None + html = await page.evaluate("document.documentElement.outerHTML;") + res = self._markdown_converter.convert_stream( + io.BytesIO(html.encode("utf-8")), file_extension=".html", url=page.url + ) + assert hasattr(res, "text_content") and isinstance(res.text_content, str) + return res.text_content + else: + return await self.get_webpage_text(page, n_lines=200) diff --git a/agent_dhal/agentdhal_extensions/auth/azure/__init__.py b/agent_dhal/agentdhal_extensions/auth/azure/__init__.py new file mode 100644 index 0000000..88d677e --- /dev/null +++ b/agent_dhal/agentdhal_extensions/auth/azure/__init__.py @@ -0,0 +1,56 @@ +from typing import List + +from agentdhal_core import Component, ComponentBase +from pydantic import BaseModel +from typing_extensions import Self + +from azure.core.credentials import TokenProvider +from azure.identity import DefaultAzureCredential, get_bearer_token_provider + + +class TokenProviderConfig(BaseModel): + provider_kind: str + scopes: List[str] + + +class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]): + component_type = "token_provider" + component_config_schema = TokenProviderConfig + component_provider_override = "agentdhal_extensions.auth.azure.AzureTokenProvider" + + def __init__(self, credential: TokenProvider, *scopes: str): + self.credential = credential + self.scopes = list(scopes) + self.provider = get_bearer_token_provider(self.credential, *self.scopes) + + def __call__(self) -> str: + return self.provider() + + def _to_config(self) -> TokenProviderConfig: + """Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance. + + Returns: + T: The configuration of the component. + """ + + if isinstance(self.credential, DefaultAzureCredential): + # NOTE: we are not currently inspecting the chained credentials, so this could result in a loss of information + return TokenProviderConfig(provider_kind="DefaultAzureCredential", scopes=self.scopes) + else: + raise ValueError("Only DefaultAzureCredential is supported") + + @classmethod + def _from_config(cls, config: TokenProviderConfig) -> Self: + """Create a new instance of the component from a configuration object. + + Args: + config (T): The configuration object. + + Returns: + Self: The new instance of the component. + """ + + if config.provider_kind == "DefaultAzureCredential": + return cls(DefaultAzureCredential(), *config.scopes) + else: + raise ValueError("Only DefaultAzureCredential is supported") diff --git a/agent_dhal/agentdhal_extensions/cache_store/__init__.py b/agent_dhal/agentdhal_extensions/cache_store/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_extensions/cache_store/diskcache.py b/agent_dhal/agentdhal_extensions/cache_store/diskcache.py new file mode 100644 index 0000000..9157b2a --- /dev/null +++ b/agent_dhal/agentdhal_extensions/cache_store/diskcache.py @@ -0,0 +1,46 @@ +from typing import Any, Optional, TypeVar, cast + +import diskcache +from agentdhal_core import CacheStore, Component +from pydantic import BaseModel +from typing_extensions import Self + +T = TypeVar("T") + + +class DiskCacheStoreConfig(BaseModel): + """Configuration for DiskCacheStore""" + + directory: str # Path where cache is stored + # Could add other diskcache.Cache parameters like size_limit, etc. + + +class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]): + """ + A typed CacheStore implementation that uses diskcache as the underlying storage. + See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage. + + Args: + cache_instance: An instance of diskcache.Cache. + The user is responsible for managing the DiskCache instance's lifetime. + """ + + component_config_schema = DiskCacheStoreConfig + component_provider_override = "agentdhal_extensions.cache_store.diskcache.DiskCacheStore" + + def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported] + self.cache = cache_instance + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + return cast(Optional[T], self.cache.get(key, default)) # type: ignore[reportUnknownMemberType] + + def set(self, key: str, value: T) -> None: + self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType] + + def _to_config(self) -> DiskCacheStoreConfig: + # Get directory from cache instance + return DiskCacheStoreConfig(directory=self.cache.directory) + + @classmethod + def _from_config(cls, config: DiskCacheStoreConfig) -> Self: + return cls(cache_instance=diskcache.Cache(config.directory)) # type: ignore[no-any-return] diff --git a/agent_dhal/agentdhal_extensions/cache_store/redis.py b/agent_dhal/agentdhal_extensions/cache_store/redis.py new file mode 100644 index 0000000..64b7770 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/cache_store/redis.py @@ -0,0 +1,142 @@ +import json +from typing import Any, Dict, Optional, TypeVar, cast + +import redis +from agentdhal_core import CacheStore, Component +from pydantic import BaseModel +from typing_extensions import Self + +T = TypeVar("T") + + +class RedisStoreConfig(BaseModel): + """Configuration for RedisStore""" + + host: str = "localhost" + port: int = 6379 + db: int = 0 + # Add other relevant redis connection parameters + username: Optional[str] = None + password: Optional[str] = None + ssl: bool = False + socket_timeout: Optional[float] = None + + +class RedisStore(CacheStore[T], Component[RedisStoreConfig]): + """ + A typed CacheStore implementation that uses redis as the underlying storage. + See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage. + + This implementation provides automatic serialization and deserialization for: + - Pydantic models (uses model_dump_json/model_validate_json) + - Primitive types (strings, numbers, etc.) + + + Args: + cache_instance: An instance of `redis.Redis`. + The user is responsible for managing the Redis instance's lifetime. + """ + + component_config_schema = RedisStoreConfig + component_provider_override = "agentdhal_extensions.cache_store.redis.RedisStore" + + def __init__(self, redis_instance: redis.Redis): + self.cache = redis_instance + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + """ + Retrieve a value from the Redis cache. + + This method handles both primitive values and complex objects: + - Pydantic models are automatically deserialized from JSON + - Primitive values (strings, numbers, etc.) are returned as-is + - If deserialization fails, returns the raw value or default + + Args: + key: The key to retrieve + default: Value to return if key doesn't exist + + Returns: + The value if found and properly deserialized, otherwise the default + """ + try: + raw_value = self.cache.get(key) + if raw_value is None: + return default + + if isinstance(raw_value, bytes): + try: + # First try to decode as UTF-8 string + decoded_str = raw_value.decode("utf-8") + try: + # Try to parse as JSON and return the parsed object + parsed_json = json.loads(decoded_str) + return cast(Optional[T], parsed_json) + except json.JSONDecodeError: + # If not valid JSON, return the decoded string. + return cast(Optional[T], decoded_str) + except UnicodeDecodeError: + return default + else: + # Backward compatibility for primitives + return cast(Optional[T], raw_value) + except (redis.RedisError, ConnectionError): + # Log Redis-specific errors but return default gracefully + return default + + def set(self, key: str, value: T) -> None: + """ + Store a value in the Redis cache. + + This method handles both primitive values and complex objects: + - Pydantic models are automatically serialized to JSON + - Primitive values (strings, numbers, etc.) are stored as-is + + Args: + key: The key to store the value under + value: The value to store + """ + try: + if isinstance(value, BaseModel): + # Serialize Pydantic models to JSON + serialized_value = value.model_dump_json().encode("utf-8") + self.cache.set(key, serialized_value) + else: + # Backward compatibility for primitives + self.cache.set(key, cast(Any, value)) + except (redis.RedisError, ConnectionError, UnicodeEncodeError): + # Log the error but don't re-raise to maintain robustness + pass + + def _to_config(self) -> RedisStoreConfig: + # Extract connection info from redis instance + connection_pool = self.cache.connection_pool + connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs # type: ignore[reportUnknownMemberType] + + username = connection_kwargs.get("username") + password = connection_kwargs.get("password") + socket_timeout = connection_kwargs.get("socket_timeout") + + return RedisStoreConfig( + host=str(connection_kwargs.get("host", "localhost")), + port=int(connection_kwargs.get("port", 6379)), + db=int(connection_kwargs.get("db", 0)), + username=str(username) if username is not None else None, + password=str(password) if password is not None else None, + ssl=bool(connection_kwargs.get("ssl", False)), + socket_timeout=float(socket_timeout) if socket_timeout is not None else None, + ) + + @classmethod + def _from_config(cls, config: RedisStoreConfig) -> Self: + # Create new redis instance from config + redis_instance = redis.Redis( + host=config.host, + port=config.port, + db=config.db, + username=config.username, + password=config.password, + ssl=config.ssl, + socket_timeout=config.socket_timeout, + ) + return cls(redis_instance=redis_instance) diff --git a/agent_dhal/agentdhal_extensions/code_executors/_common.py b/agent_dhal/agentdhal_extensions/code_executors/_common.py new file mode 100644 index 0000000..509b7dd --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/_common.py @@ -0,0 +1,199 @@ +import inspect +import re +import shutil +from dataclasses import dataclass +from pathlib import Path +from textwrap import dedent, indent +from typing import Any, Callable, Optional, Sequence, Set, TypeVar, Union + +from agentdhal_core.code_executor import Alias, CodeResult, FunctionWithRequirements, FunctionWithRequirementsStr, Import +from typing_extensions import ParamSpec + + +@dataclass +class CommandLineCodeResult(CodeResult): + """A code result class for command line code executor.""" + + code_file: Optional[str] + + +T = TypeVar("T") +P = ParamSpec("P") + + +def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + if isinstance(func, FunctionWithRequirementsStr): + return func.func + + code = inspect.getsource(func) + # Strip the decorator + if code.startswith("@"): + code = code[code.index("\n") + 1 :] + return code + + +def _import_to_str(im: Import) -> str: + if isinstance(im, str): + return f"import {im}" + elif isinstance(im, Alias): + return f"import {im.name} as {im.alias}" + else: + + def to_str(i: Union[str, Alias]) -> str: + if isinstance(i, str): + return i + else: + return f"{i.name} as {i.alias}" + + imports = ", ".join(map(to_str, im.imports)) + return f"from {im.module} import {imports}" + + +def build_python_functions_file( + funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]], +) -> str: + """:meta private:""" + # First collect all global imports + global_imports: Set[Import] = set() + for func in funcs: + if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): + global_imports.update(func.global_imports) + + content = "\n".join(map(_import_to_str, global_imports)) + "\n\n" + + for func in funcs: + content += _to_code(func) + "\n\n" + + return content + + +def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str: + """Generate a stub for a function as a string + + Args: + func (Callable[..., Any]): The function to generate a stub for + + Returns: + str: The stub for the function + """ + if isinstance(func, FunctionWithRequirementsStr): + return to_stub(func.compiled_func) + + content = f"def {func.__name__}{inspect.signature(func)}:\n" + docstring = func.__doc__ + + if docstring: + docstring = dedent(docstring) + docstring = '"""' + docstring + '"""' + docstring = indent(docstring, " ") + content += docstring + "\n" + + content += " ..." + return content + + +# Raises ValueError if the file is not in the workspace +def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]: + first_line = code.split("\n")[0] + # TODO - support other languages + if first_line.startswith("# filename:"): + filename = first_line.split(":")[1].strip() + + # Handle relative paths in the filename + path = Path(filename) + if not path.is_absolute(): + path = workspace_path / path + path = path.resolve() + # Throws an error if the file is not in the workspace + relative = path.relative_to(workspace_path.resolve()) + return str(relative) + + return None + + +def silence_pip(code: str, lang: str) -> str: + """Apply -qqq flag to pip install commands.""" + if lang == "python": + regex = r"^! ?pip install" + elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]: + regex = r"^pip install" + else: + return code + + # Find lines that start with pip install and make sure "-qqq" flag is added. + lines = code.split("\n") + for i, line in enumerate(lines): + # use regex to find lines that start with pip install. + match = re.search(regex, line) + if match is not None: + if "-qqq" not in line: + lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") + return "\n".join(lines) + + +def get_required_packages(code: str, lang: str) -> set[str]: + ret: set[str] = set() + if lang == "python": + regex = r"^! ?pip install(.*)$" + else: + return ret + + # Find lines that start with pip install and make sure "-qqq" flag is added. + lines = code.split("\n") + for _, line in enumerate(lines): + # use regex to find lines that start with pip install. + match = re.search(regex, line) + if match is not None: + reqs = match.group(1).split(",") + ret = {req.strip(" ") for req in reqs} + return ret + + +PYTHON_VARIANTS = ["python", "Python", "py"] + + +def lang_to_cmd(lang: str) -> str: + if lang in PYTHON_VARIANTS: + return "python" + if lang.startswith("python") or lang in ["bash", "sh"]: + return lang + if lang in ["shell"]: + return "sh" + if lang in ["pwsh", "powershell", "ps1"]: + # Check if pwsh is available, otherwise fall back to powershell + if shutil.which("pwsh") is not None: + return "pwsh" + elif shutil.which("powershell") is not None: + return "powershell" + else: + raise ValueError("Powershell or pwsh is not installed. Please install one of them.") + else: + raise ValueError(f"Unsupported language: {lang}") + + +# Regular expression for finding a code block +# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks. +# The [ \t]* matches the potential spaces before language name. +# The (\w+)? matches the language, where the ? indicates it is optional. +# The [ \t]* matches the potential spaces (not newlines) after language name. +# The \r?\n makes sure there is a linebreak after ```. +# The (.*?) matches the code itself (non-greedy). +# The \r?\n makes sure there is a linebreak before ```. +# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation). +CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```" + + +def infer_lang(code: str) -> str: + """infer the language for the code. + TODO: make it robust. + """ + if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "): + return "sh" + + # check if code is a valid python code + try: + compile(code, "test", "exec") + return "python" + except SyntaxError: + # not a valid python code + return "unknown" diff --git a/agent_dhal/agentdhal_extensions/code_executors/azure/__init__.py b/agent_dhal/agentdhal_extensions/code_executors/azure/__init__.py new file mode 100644 index 0000000..33c79b0 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/azure/__init__.py @@ -0,0 +1,3 @@ +from ._azure_container_code_executor import ACADynamicSessionsCodeExecutor, TokenProvider + +__all__ = ["TokenProvider", "ACADynamicSessionsCodeExecutor"] diff --git a/agent_dhal/agentdhal_extensions/code_executors/azure/_azure_container_code_executor.py b/agent_dhal/agentdhal_extensions/code_executors/azure/_azure_container_code_executor.py new file mode 100644 index 0000000..b5d8a14 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/azure/_azure_container_code_executor.py @@ -0,0 +1,522 @@ +# Credit to original authors + +from __future__ import annotations + +import asyncio +import os +import tempfile +import warnings +from pathlib import Path +from string import Template +from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Protocol, Sequence, Union +from uuid import uuid4 + +import aiohttp + +# async functions shouldn't use open() +from anyio import open_file +from agentdhal_core import CancellationToken +from agentdhal_core.code_executor import ( + CodeBlock, + CodeExecutor, + CodeResult, + FunctionWithRequirements, + FunctionWithRequirementsStr, +) +from typing_extensions import ParamSpec + +from .._common import build_python_functions_file, get_required_packages, to_stub + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + +PYTHON_VARIANTS = ["python", "Python", "py"] + +__all__ = ("ACADynamicSessionsCodeExecutor", "TokenProvider") + +A = ParamSpec("A") + + +class TokenProvider(Protocol): + def get_token( + self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + ) -> AccessToken: ... + + +class ACADynamicSessionsCodeExecutor(CodeExecutor): + """(Experimental) A code executor class that executes code through a an Azure + Container Apps Dynamic Sessions instance. + + .. note:: + + This class requires the :code:`azure` extra for the :code:`autogen-ext` package: + + .. code-block:: bash + + pip install "agentdhal-ext[azure]" + + .. caution:: + + **This will execute LLM generated code on an Azure dynamic code container.** + + The execution environment is similar to that of a jupyter notebook which allows for incremental code execution. The parameter functions are executed in order once at the beginning of each session. Each code block is then executed serially and in the order they are received. Each environment has a statically defined set of available packages which cannot be changed. + Currently, attempting to use packages beyond what is available on the environment will result in an error. To get the list of supported packages, call the `get_available_packages` function. + Currently the only supported language is Python. + For Python code, use the language "python" for the code block. + + Args: + pool_management_endpoint (str): The azure container apps dynamic sessions endpoint. + credential (TokenProvider): An object that implements the get_token function. + timeout (int): The timeout for the execution of any single code block. Default is 60. + work_dir (str): The working directory for the code execution. If None, + a default working directory will be used. The default working + directory is a temporal directory. + functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. + suppress_result_output bool: By default the executor will attach any result info in the execution response to the result outpu. Set this to True to prevent this. + session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart` + + .. note:: + Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning. + """ + + SUPPORTED_LANGUAGES: ClassVar[List[str]] = [ + "python", + ] + FUNCTION_PROMPT_TEMPLATE: ClassVar[str] = """You have access to the following user defined functions. + +$functions""" + + _AZURE_API_VER = "2024-02-02-preview" + + def __init__( + self, + pool_management_endpoint: str, + credential: TokenProvider, + timeout: int = 60, + work_dir: Union[Path, str, None] = None, + functions: Sequence[ + Union[ + FunctionWithRequirements[Any, A], + Callable[..., Any], + FunctionWithRequirementsStr, + ] + ] = [], + functions_module: str = "functions", + suppress_result_output: bool = False, + session_id: Optional[str] = None, + ): + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + self._work_dir: Optional[Path] = None + self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None + + # If a user specifies a working directory, use that + if work_dir is not None: + if isinstance(work_dir, str): + self._work_dir = Path(work_dir) + else: + self._work_dir = work_dir + # Create the directory if it doesn't exist + self._work_dir.mkdir(exist_ok=True, parents=True) + # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory) + else: + self._temp_dir = tempfile.TemporaryDirectory() + temp_dir_path = Path(self._temp_dir.name) + temp_dir_path.mkdir(exist_ok=True, parents=True) + + self._started = False + + # Rest of initialization remains the same + self._functions_module = functions_module + self._timeout = timeout + self._functions = functions + self._func_code: Optional[str] = None + + # Setup could take some time so we intentionally wait for the first code block to do it. + if len(functions) > 0: + self._setup_functions_complete = False + else: + self._setup_functions_complete = True + + self._suppress_result_output = suppress_result_output + + self._pool_management_endpoint = pool_management_endpoint + self._access_token: str | None = None + self._session_id: str = session_id or str(uuid4()) + self._available_packages: set[str] | None = None + self._credential: TokenProvider = credential + # cwd needs to be set to /mnt/data to properly read uploaded files and download written files + self._setup_cwd_complete = False + + # TODO: expiration? + def _ensure_access_token(self) -> None: + if not self._access_token: + scope = "https://dynamicsessions.io/.default" + self._access_token = self._credential.get_token(scope).token + + def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str: + """(Experimental) Format the functions for a prompt. + + The template includes one variable: + - `$functions`: The functions formatted as stubs with two newlines between each function. + + Args: + prompt_template (str): The prompt template. Default is the class default. + + Returns: + str: The formatted prompt. + """ + + template = Template(prompt_template) + return template.substitute( + functions="\n\n".join([to_stub(func) for func in self._functions]), + ) + + @property + def functions_module(self) -> str: + """(Experimental) The module name for the functions.""" + return self._functions_module + + @property + def functions(self) -> List[str]: + raise NotImplementedError + + @property + def timeout(self) -> int: + """(Experimental) The timeout for code execution.""" + return self._timeout + + @property + def work_dir(self) -> Path: + # If a user specifies a working directory, use that + if self._work_dir is not None: + # If a user specifies the current directory, warn them that this is deprecated + if self._work_dir == Path("."): + warnings.warn( + "Using the current directory as work_dir is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return self._work_dir + # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory) + elif self._temp_dir is not None: + return Path(self._temp_dir.name) + else: + raise RuntimeError("Working directory not properly initialized") + + def _construct_url(self, path: str) -> str: + endpoint = self._pool_management_endpoint + if not endpoint.endswith("/"): + endpoint += "/" + url = endpoint + f"{path}?api-version={self._AZURE_API_VER}&identifier={self._session_id}" + return url + + async def get_available_packages(self, cancellation_token: CancellationToken) -> set[str]: + if self._available_packages is not None: + return self._available_packages + avail_pkgs = """ +import pkg_resources\n[d.project_name for d in pkg_resources.working_set] +""" + ret = await self._execute_code_dont_check_setup( + [CodeBlock(code=avail_pkgs, language="python")], cancellation_token + ) + if ret.exit_code != 0: + raise ValueError(f"Failed to get list of available packages: {ret.output.strip()}") + pkgs = ret.output.strip("[]") + pkglist = pkgs.split(",\n") + return {pkg.strip(" '") for pkg in pkglist} + + async def _populate_available_packages(self, cancellation_token: CancellationToken) -> None: + self._available_packages = await self.get_available_packages(cancellation_token) + + async def _setup_functions(self, cancellation_token: CancellationToken) -> None: + if not self._func_code: + self._func_code = build_python_functions_file(self._functions) + + # Check required function imports and packages + lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] + # Should we also be checking the imports? + + flattened_packages = [item for sublist in lists_of_packages for item in sublist] + required_packages = set(flattened_packages) + + if self._available_packages is None: + await self._populate_available_packages(cancellation_token) + + if self._available_packages is not None: + missing_pkgs = set(required_packages - self._available_packages) + if len(missing_pkgs) > 0: + raise ValueError(f"Packages unavailable in environment: {missing_pkgs}") + + func_file = self.work_dir / f"{self._functions_module}.py" + func_file.write_text(self._func_code) + + # Attempt to load the function file to check for syntax errors, imports etc. + exec_result = await self._execute_code_dont_check_setup( + [CodeBlock(code=self._func_code, language="python")], cancellation_token + ) + + if exec_result.exit_code != 0: + raise ValueError(f"Functions failed to load: {exec_result.output.strip()}") + + self._setup_functions_complete = True + + async def _setup_cwd(self, cancellation_token: CancellationToken) -> None: + # Change the cwd to /mnt/data to properly have access to uploaded files + exec_result = await self._execute_code_dont_check_setup( + [CodeBlock(code="import os; os.chdir('/mnt/data')", language="python")], cancellation_token + ) + + if exec_result.exit_code != 0: + raise ValueError("Failed to set up Azure container working directory") + self._setup_cwd_complete = True + + async def get_file_list(self, cancellation_token: CancellationToken) -> List[str]: + self._ensure_access_token() + timeout = aiohttp.ClientTimeout(total=float(self._timeout)) + headers = { + "Authorization": f"Bearer {self._access_token}", + } + url = self._construct_url("files") + async with aiohttp.ClientSession(timeout=timeout) as client: + task = asyncio.create_task( + client.get( + url, + headers=headers, + ) + ) + cancellation_token.link_future(task) + try: + resp = await task + resp.raise_for_status() + data = await resp.json() + except asyncio.TimeoutError as e: + # e.add_note is only in py 3.11+ + raise asyncio.TimeoutError("Timeout getting file list") from e + except asyncio.CancelledError as e: + # e.add_note is only in py 3.11+ + raise asyncio.CancelledError("File list retrieval cancelled") from e + except aiohttp.ClientResponseError as e: + raise ConnectionError("Error while getting file list") from e + + values = data["value"] + file_info_list: List[str] = [] + for value in values: + file = value["properties"] + file_info_list.append(file["filename"]) + return file_info_list + + async def upload_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> None: + self._ensure_access_token() + # TODO: Better to use the client auth system rather than headers + headers = {"Authorization": f"Bearer {self._access_token}"} + url = self._construct_url("files/upload") + timeout = aiohttp.ClientTimeout(total=float(self._timeout)) + async with aiohttp.ClientSession(timeout=timeout) as client: + for file in files: + file_path = self.work_dir / file + if not file_path.is_file(): + # TODO: what to do here? + raise FileNotFoundError(f"{file} does not exist") + + data = aiohttp.FormData() + async with await open_file(file_path, "rb") as f: + data.add_field( + "file", + f, + filename=os.path.basename(file_path), + content_type="application/octet-stream", + ) + + task = asyncio.create_task( + client.post( + url, + headers=headers, + data=data, + ) + ) + + cancellation_token.link_future(task) + try: + resp = await task + resp.raise_for_status() + + except asyncio.TimeoutError as e: + # e.add_note is only in py 3.11+ + raise asyncio.TimeoutError("Timeout uploading files") from e + except asyncio.CancelledError as e: + # e.add_note is only in py 3.11+ + raise asyncio.CancelledError("Uploading files cancelled") from e + except aiohttp.ClientResponseError as e: + raise ConnectionError("Error while uploading files") from e + + async def download_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> List[str]: + self._ensure_access_token() + available_files = await self.get_file_list(cancellation_token) + # TODO: Better to use the client auth system rather than headers + headers = {"Authorization": f"Bearer {self._access_token}"} + timeout = aiohttp.ClientTimeout(total=float(self._timeout)) + local_paths: List[str] = [] + async with aiohttp.ClientSession(timeout=timeout) as client: + for file in files: + if file not in available_files: + # TODO: what's the right thing to do here? + raise FileNotFoundError(f"{file} does not exist") + + url = self._construct_url(f"files/content/{file}") + + task = asyncio.create_task( + client.get( + url, + headers=headers, + ) + ) + cancellation_token.link_future(task) + try: + resp = await task + resp.raise_for_status() + local_path = self.work_dir / file + local_paths.append(str(local_path)) + async with await open_file(local_path, "wb") as f: + await f.write(await resp.read()) + except asyncio.TimeoutError as e: + # e.add_note is only in py 3.11+ + raise asyncio.TimeoutError("Timeout downloading files") from e + except asyncio.CancelledError as e: + # e.add_note is only in py 3.11+ + raise asyncio.CancelledError("Downloading files cancelled") from e + except aiohttp.ClientResponseError as e: + raise ConnectionError("Error while downloading files") from e + return local_paths + + async def execute_code_blocks( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CodeResult: + """(Experimental) Execute the code blocks and return the result. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + cancellation_token (CancellationToken): a token to cancel the operation + input_files (Optional[Union[Path, str]]): Any files the code blocks will need to access + + Returns: + CodeResult: The result of the code execution.""" + + self._ensure_access_token() + if self._available_packages is None: + await self._populate_available_packages(cancellation_token) + if not self._setup_functions_complete: + await self._setup_functions(cancellation_token) + if not self._setup_cwd_complete: + await self._setup_cwd(cancellation_token) + + return await self._execute_code_dont_check_setup(code_blocks, cancellation_token) + + # The http call here should be replaced by an actual Azure client call once its available + async def _execute_code_dont_check_setup( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CodeResult: + logs_all = "" + exitcode = 0 + + # TODO: Better to use the client auth system rather than headers + assert self._access_token is not None + headers = { + "Authorization": f"Bearer {self._access_token}", + "Content-Type": "application/json", + } + properties = { + "codeInputType": "inline", + "executionType": "synchronous", + "code": "", # Filled in later + } + url = self._construct_url("code/execute") + timeout = aiohttp.ClientTimeout(total=float(self._timeout)) + async with aiohttp.ClientSession(timeout=timeout) as client: + for code_block in code_blocks: + lang, code = code_block.language, code_block.code + lang = lang.lower() + + if lang in PYTHON_VARIANTS: + lang = "python" + + if lang not in self.SUPPORTED_LANGUAGES: + # In case the language is not supported, we return an error message. + exitcode = 1 + logs_all += "\n" + f"unknown language {lang}" + break + + if self._available_packages is not None: + req_pkgs = get_required_packages(code, lang) + missing_pkgs = set(req_pkgs - self._available_packages) + if len(missing_pkgs) > 0: + # In case the code requires packages that are not available in the environment + exitcode = 1 + logs_all += "\n" + f"Python packages unavailable in environment: {missing_pkgs}" + break + + properties["code"] = code_block.code + + task = asyncio.create_task( + client.post( + url, + headers=headers, + json={"properties": properties}, + ) + ) + + cancellation_token.link_future(task) + try: + response = await task + response.raise_for_status() + data = await response.json() + data = data["properties"] + logs_all += data.get("stderr", "") + data.get("stdout", "") + if "Success" in data["status"]: + if not self._suppress_result_output: + logs_all += str(data["result"]) + elif "Failure" in data["status"]: + exitcode = 1 + + except asyncio.TimeoutError as e: + logs_all += "\n Timeout" + # e.add_note is only in py 3.11+ + raise asyncio.TimeoutError(logs_all) from e + except asyncio.CancelledError as e: + logs_all += "\n Cancelled" + # e.add_note is only in py 3.11+ + raise asyncio.CancelledError(logs_all) from e + except aiohttp.ClientResponseError as e: + logs_all += "\nError while sending code block to endpoint" + raise ConnectionError(logs_all) from e + + return CodeResult(exit_code=exitcode, output=logs_all) + + async def restart(self) -> None: + """(Experimental) Restart the code executor. + + Resets the internal state of the executor by generating a new session ID and resetting the setup variables. + This causes the next code execution to reinitialize the environment and re-run any setup code. + """ + self._session_id = str(uuid4()) + self._setup_functions_complete = False + self._access_token = None + self._available_packages = None + self._setup_cwd_complete = False + + async def start(self) -> None: + """(Experimental) Start the code executor. + + Marks the code executor as started.""" + # No setup needed for this executor + self._started = True + + async def stop(self) -> None: + """(Experimental) Stop the code executor. + + Stops the code executor after cleaning up the temporary working directory (if it was created).""" + if self._temp_dir is not None: + self._temp_dir.cleanup() + self._temp_dir = None + self._started = False diff --git a/agent_dhal/agentdhal_extensions/code_executors/docker/__init__.py b/agent_dhal/agentdhal_extensions/code_executors/docker/__init__.py new file mode 100644 index 0000000..4241843 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/docker/__init__.py @@ -0,0 +1,3 @@ +from ._docker_code_executor import DockerCommandLineCodeExecutor + +__all__ = ["DockerCommandLineCodeExecutor"] diff --git a/agent_dhal/agentdhal_extensions/code_executors/docker/_docker_code_executor.py b/agent_dhal/agentdhal_extensions/code_executors/docker/_docker_code_executor.py new file mode 100644 index 0000000..f749cd0 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/docker/_docker_code_executor.py @@ -0,0 +1,613 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/docker_commandline_code_executor.py +# Credit to original authors + +from __future__ import annotations + +import asyncio +import logging +import shlex +import sys +import tempfile +import uuid +import warnings +from collections.abc import Sequence +from concurrent.futures import Future as ConcurrentFuture +from hashlib import sha256 +from pathlib import Path +from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Tuple, Union + +from agentdhal_core import CancellationToken, Component +from agentdhal_core.code_executor import ( + CodeBlock, + CodeExecutor, + FunctionWithRequirements, + FunctionWithRequirementsStr, +) +from pydantic import BaseModel +from typing_extensions import Self + +from docker.types import DeviceRequest + +from .._common import ( + CommandLineCodeResult, + build_python_functions_file, + get_file_name_from_content, + lang_to_cmd, + silence_pip, +) + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +try: + import asyncio_atexit + + import docker + from docker.errors import DockerException, ImageNotFound, NotFound + from docker.models.containers import Container +except ImportError as e: + raise RuntimeError( + "Missing dependecies for DockerCommandLineCodeExecutor. Please ensure the autogen-ext package was installed with the 'docker' extra." + ) from e + + +async def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -> None: + elapsed_time = 0.0 + while container.status != "running" and elapsed_time < timeout: + await asyncio.sleep(stop_time) + elapsed_time += stop_time + await asyncio.to_thread(container.reload) + continue + if container.status != "running": + raise ValueError("Container failed to start") + + +A = ParamSpec("A") + + +class DockerCommandLineCodeExecutorConfig(BaseModel): + """Configuration for DockerCommandLineCodeExecutor""" + + image: str = "python:3-slim" + container_name: Optional[str] = None + timeout: int = 60 + work_dir: Optional[str] = None + bind_dir: Optional[str] = None + auto_remove: bool = True + stop_container: bool = True + functions_module: str = "functions" + extra_volumes: Dict[str, Dict[str, str]] = {} + extra_hosts: Dict[str, str] = {} + init_command: Optional[str] = None + delete_tmp_files: bool = False + + +class DockerCommandLineCodeExecutor(CodeExecutor, Component[DockerCommandLineCodeExecutorConfig]): + """Executes code through a command line environment in a Docker container. + + .. note:: + + This class requires the :code:`docker` extra for the :code:`autogen-ext` package: + + .. code-block:: bash + + pip install "agentdhal-ext[docker]" + + + The executor first saves each code block in a file in the working + directory, and then executes the code file in the container. + The executor executes the code blocks in the order they are received. + Currently, the executor only supports Python and shell scripts. + For Python code, use the language "python" for the code block. + For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code block. + + Args: + image (_type_, optional): Docker image to use for code execution. + Defaults to "python:3-slim". + container_name (Optional[str], optional): Name of the Docker container + which is created. If None, will autogenerate a name. Defaults to None. + timeout (int, optional): The timeout for code execution. Defaults to 60. + work_dir (Union[Path, str], optional): The working directory for the code + execution. Defaults to temporary directory. + bind_dir (Union[Path, str], optional): The directory that will be bound + to the code executor container. Useful for cases where you want to spawn + the container from within a container. Defaults to work_dir. + auto_remove (bool, optional): If true, will automatically remove the Docker + container when it is stopped. Defaults to True. + stop_container (bool, optional): If true, will automatically stop the + container when stop is called, when the context manager exits or when + the Python process exits with atext. Defaults to True. + device_requests (Optional[List[DeviceRequest]], optional): A list of device request instances to add to the container for exposing GPUs (e.g., [docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])]). Defaults to None. + functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. + functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions". + extra_volumes (Optional[Dict[str, Dict[str, str]]], optional): A dictionary of extra volumes (beyond the work_dir) to mount to the container; + key is host source path and value 'bind' is the container path. See Defaults to None. + Example: extra_volumes = {'/home/user1/': {'bind': '/mnt/vol2', 'mode': 'rw'}, '/var/www': {'bind': '/mnt/vol1', 'mode': 'ro'}} + extra_hosts (Optional[Dict[str, str]], optional): A dictionary of host mappings to add to the container. (See Docker docs on extra_hosts) Defaults to None. + Example: extra_hosts = {"kubernetes.docker.internal": "host-gateway"} + init_command (Optional[str], optional): A shell command to run before each shell operation execution. Defaults to None. + Example: init_command="kubectl config use-context docker-hub" + delete_tmp_files (bool, optional): If true, will delete temporary files after execution. Defaults to False. + + .. note:: + Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning. + + """ + + component_config_schema = DockerCommandLineCodeExecutorConfig + component_provider_override = "agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor" + + SUPPORTED_LANGUAGES: ClassVar[List[str]] = [ + "bash", + "shell", + "sh", + "pwsh", + "powershell", + "ps1", + "python", + ] + + FUNCTION_PROMPT_TEMPLATE: ClassVar[ + str + ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names. + +For example, if there was a function called `foo` you could import it by writing `from $module_name import foo` + +$functions""" + + def __init__( + self, + image: str = "python:3-slim", + container_name: Optional[str] = None, + *, + timeout: int = 60, + work_dir: Union[Path, str, None] = None, + bind_dir: Optional[Union[Path, str]] = None, + auto_remove: bool = True, + stop_container: bool = True, + device_requests: Optional[List[DeviceRequest]] = None, + functions: Sequence[ + Union[ + FunctionWithRequirements[Any, A], + Callable[..., Any], + FunctionWithRequirementsStr, + ] + ] = [], + functions_module: str = "functions", + extra_volumes: Optional[Dict[str, Dict[str, str]]] = None, + extra_hosts: Optional[Dict[str, str]] = None, + init_command: Optional[str] = None, + delete_tmp_files: bool = False, + ): + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + # Handle working directory logic + if work_dir is None: + self._work_dir = None + else: + if isinstance(work_dir, str): + work_dir = Path(work_dir) + # Emit a deprecation warning if the user is using the current directory as working directory + if work_dir.resolve() == Path.cwd().resolve(): + warnings.warn( + "Using the current directory as work_dir is deprecated.", + DeprecationWarning, + stacklevel=2, + ) + self._work_dir = work_dir + # Create the working directory if it doesn't exist + self._work_dir.mkdir(exist_ok=True, parents=True) + + if container_name is None: + self.container_name = f"agentdhal-code-exec-{uuid.uuid4()}" + else: + self.container_name = container_name + + self._timeout = timeout + + # Handle bind_dir + self._bind_dir: Optional[Path] = None + if bind_dir is not None: + self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir + else: + self._bind_dir = self._work_dir # Default to work_dir if not provided + + # Track temporary directory + self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None + self._temp_dir_path: Optional[Path] = None + + self._started = False + + self._auto_remove = auto_remove + self._stop_container = stop_container + self._image = image + + if not functions_module.isidentifier(): + raise ValueError("Module name must be a valid Python identifier") + + self._functions_module = functions_module + self._functions = functions + self._extra_volumes = extra_volumes if extra_volumes is not None else {} + self._extra_hosts = extra_hosts if extra_hosts is not None else {} + self._init_command = init_command + self._delete_tmp_files = delete_tmp_files + self._device_requests = device_requests + + # Setup could take some time so we intentionally wait for the first code block to do it. + if len(functions) > 0: + self._setup_functions_complete = False + else: + self._setup_functions_complete = True + + self._container: Container | None = None + self._running = False + + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._cancellation_futures: List[ConcurrentFuture[None]] = [] + + @property + def timeout(self) -> int: + """(Experimental) The timeout for code execution.""" + return self._timeout + + async def _setup_functions(self, cancellation_token: CancellationToken) -> None: + func_file_content = build_python_functions_file(self._functions) + func_file = self.work_dir / f"{self._functions_module}.py" + func_file.write_text(func_file_content) + + # Collect requirements + lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] + flattened_packages = [item for sublist in lists_of_packages for item in sublist] + required_packages = list(set(flattened_packages)) + if len(required_packages) > 0: + logging.info("Ensuring packages are installed in executor.") + + packages = shlex.join(required_packages) + + result = await self._execute_code_dont_check_setup( + [CodeBlock(code=f"python -m pip install {packages}", language="sh")], cancellation_token + ) + + if result.exit_code != 0: + stdout = result.output + stderr = result.output + raise ValueError(f"Pip install failed. {stdout}, {stderr}") + + # Attempt to load the function file to check for syntax errors, imports etc. + exec_result = await self._execute_code_dont_check_setup( + [CodeBlock(code=func_file_content, language="python")], cancellation_token + ) + + if exec_result.exit_code != 0: + raise ValueError(f"Functions failed to load: {exec_result.output}") + + self._setup_functions_complete = True + + async def _kill_running_command(self, command: List[str]) -> None: + if self._container is None or not self._running: + return + await asyncio.to_thread(self._container.exec_run, ["pkill", "-f", " ".join(command)]) + + async def _execute_command(self, command: List[str], cancellation_token: CancellationToken) -> Tuple[str, int]: + if self._container is None or not self._running: + raise ValueError("Container is not running. Must first be started with either start or a context manager.") + + exec_task = asyncio.create_task(asyncio.to_thread(self._container.exec_run, command)) + cancellation_token.link_future(exec_task) + + # Wait for the exec task to finish. + try: + result = await exec_task + exit_code = result.exit_code + output = result.output.decode("utf-8") + if exit_code == 124: + output += "\n Timeout" + return output, exit_code + except asyncio.CancelledError: + # Schedule a task to kill the running command in the background. + if self._loop and not self._loop.is_closed(): + try: + logging.debug(f"Scheduling kill command via run_coroutine_threadsafe on loop {self._loop!r}") + future: ConcurrentFuture[None] = asyncio.run_coroutine_threadsafe( + self._kill_running_command(command), self._loop + ) + self._cancellation_futures.append(future) + logging.debug(f"Kill command scheduled, future: {future!r}") + except RuntimeError as e: + logging.error(f"Failed to schedule kill command on loop {self._loop!r}: {e}") + except Exception as e: + logging.exception(f"Unexpected error scheduling kill command: {e}") + else: + logging.warning( + f"Cannot schedule kill command: Executor loop is not available or closed (loop: {self._loop!r})." + ) + return "Code execution was cancelled.", 1 + + async def _execute_code_dont_check_setup( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CommandLineCodeResult: + if self._container is None or not self._running: + raise ValueError("Container is not running. Must first be started with either start or a context manager.") + + if len(code_blocks) == 0: + raise ValueError("No code blocks to execute.") + + outputs: List[str] = [] + files: List[Path] = [] + last_exit_code = 0 + try: + for code_block in code_blocks: + lang = code_block.language.lower() + code = silence_pip(code_block.code, lang) + + # Check if there is a filename comment + try: + filename = get_file_name_from_content(code, self.work_dir) + except ValueError: + outputs.append("Filename is not in the workspace") + last_exit_code = 1 + break + + if not filename: + filename = f"tmp_code_{sha256(code.encode()).hexdigest()}.{lang}" + + code_path = self.work_dir / filename + with code_path.open("w", encoding="utf-8") as fout: + fout.write(code) + files.append(code_path) + + command = ["timeout", str(self._timeout), lang_to_cmd(lang), filename] + + output, exit_code = await self._execute_command(command, cancellation_token) + outputs.append(output) + last_exit_code = exit_code + if exit_code != 0: + break + finally: + if self._delete_tmp_files: + for file in files: + try: + file.unlink() + except (OSError, FileNotFoundError): + pass + + code_file = str(files[0]) if files else None + return CommandLineCodeResult(exit_code=last_exit_code, output="".join(outputs), code_file=code_file) + + @property + def work_dir(self) -> Path: + # If a user specifies a working directory, use that + if self._work_dir is not None: + # If a user specifies the current directory, warn them that this is deprecated + if self._work_dir == Path("."): + warnings.warn( + "Using the current directory as work_dir is deprecated.", + DeprecationWarning, + stacklevel=2, + ) + return self._work_dir + # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory) + elif self._temp_dir is not None: + return Path(self._temp_dir.name) + else: + raise RuntimeError("Working directory not properly initialized") + + @property + def bind_dir(self) -> Path: + # If the user specified a bind directory, return it + if self._bind_dir is not None: + return self._bind_dir + # Otherwise bind_dir is set to the current work_dir as default + else: + return self.work_dir + + async def execute_code_blocks( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CommandLineCodeResult: + """(Experimental) Execute the code blocks and return the result. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CommandlineCodeResult: The result of the code execution.""" + + if not self._setup_functions_complete: + await self._setup_functions(cancellation_token) + + return await self._execute_code_dont_check_setup(code_blocks, cancellation_token) + + async def restart(self) -> None: + """(Experimental) Restart the Docker container code executor.""" + if self._container is None or not self._running: + raise ValueError("Container is not running. Must first be started with either start or a context manager.") + + await asyncio.to_thread(self._container.restart) # type: ignore + if self._container.status != "running": + self._running = False + logs_str = self._container.logs().decode("utf-8") + raise ValueError(f"Failed to restart container. Logs: {logs_str}") + + async def stop(self) -> None: + """(Experimental) Stop the code executor. + + Stops the Docker container and cleans up any temporary files (if they were created), along with the temporary directory. + The method first waits for all cancellation tasks to finish before stopping the container. Finally it marks the executor as not running. + If the container is not running, the method does nothing. + """ + if not self._running: + return + + if self._temp_dir is not None: + self._temp_dir.cleanup() + self._temp_dir = None + + client = docker.from_env() + try: + try: + container = await asyncio.to_thread(client.containers.get, self.container_name) + except NotFound: + logging.debug(f"Container {self.container_name} not found during stop...") + self._running = False + self._cancellation_futures.clear() + return + + if self._cancellation_futures: + if not self._loop or self._loop.is_closed(): + logging.warning( + f"Executor loop ({self._loop!r}) is closed or unavailable. Cannot reliably wait for " + f"{len(self._cancellation_futures)} cancellation futures." + ) + self._cancellation_futures.clear() + else: + # concurrent.futures.Future -> asyncio.Future + asyncio_futures = [asyncio.wrap_future(f, loop=self._loop) for f in self._cancellation_futures] + + if asyncio_futures: + logging.debug( + f"Waiting for {len(asyncio_futures)} cancellation futures to complete on loop {self._loop!r}..." + ) + results = await asyncio.gather(*asyncio_futures, return_exceptions=True) + for i, result in enumerate(results): + original_future = self._cancellation_futures[i] + if isinstance(result, Exception): + logging.warning(f"Cancellation future {original_future!r} failed: {result}") + else: + logging.debug(f"Cancellation future {original_future!r} completed successfully.") + else: + logging.debug("No valid cancellation futures to await.") + + self._cancellation_futures.clear() + + logging.debug(f"Stopping container {self.container_name}...") + await asyncio.to_thread(container.stop) + logging.debug(f"Container {self.container_name} stopped.") + + except DockerException as e: + logging.error(f"Docker error while stopping container {self.container_name}: {e}") + except Exception as e: + logging.exception(f"Unexpected error during stop operation for container {self.container_name}: {e}") + finally: + self._running = False + self._cancellation_futures.clear() + + async def start(self) -> None: + """(Experimental) Start the code executor. + + This method sets the working environment variables, connects to Docker and starts the code executor. + If no working directory was provided to the code executor, it creates a temporary directory and sets it as the code executor working directory. + """ + + if self._work_dir is None and self._temp_dir is None: + self._temp_dir = tempfile.TemporaryDirectory() + self._temp_dir_path = Path(self._temp_dir.name) + self._temp_dir_path.mkdir(exist_ok=True) + + # Start a container from the image, read to exec commands later + try: + client = docker.from_env() + except DockerException as e: + if "FileNotFoundError" in str(e): + raise RuntimeError("Failed to connect to Docker. Please ensure Docker is installed and running.") from e + raise + except Exception as e: + raise RuntimeError(f"Unexpected error while connecting to Docker: {str(e)}") from e + + # Check if the image exists + try: + await asyncio.to_thread(client.images.get, self._image) + except ImageNotFound: + # TODO logger + logging.info(f"Pulling image {self._image}...") + # Let the docker exception escape if this fails. + await asyncio.to_thread(client.images.pull, self._image) + + # Prepare the command (if needed) + shell_command = "/bin/sh" + command = ["-c", f"{(self._init_command)};exec {shell_command}"] if self._init_command else None + + # Check if a container with the same name already exists and remove it + try: + existing_container = await asyncio.to_thread(client.containers.get, self.container_name) + await asyncio.to_thread(existing_container.remove, force=True) + except NotFound: + pass + + self._container = await asyncio.to_thread( + client.containers.create, + self._image, + name=self.container_name, + entrypoint=shell_command, + command=command, + tty=True, + detach=True, + auto_remove=self._auto_remove, + volumes={str(self.bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}, **self._extra_volumes}, + working_dir="/workspace", + extra_hosts=self._extra_hosts, + device_requests=self._device_requests, + ) + await asyncio.to_thread(self._container.start) + + await _wait_for_ready(self._container) + + async def cleanup() -> None: + await self.stop() + asyncio_atexit.unregister(cleanup) # type: ignore + + if self._stop_container: + asyncio_atexit.register(cleanup) # type: ignore + + # Check if the container is running + if self._container.status != "running": + logs_str = self._container.logs().decode("utf-8") + raise ValueError(f"Failed to start container from image {self._image}. Logs: {logs_str}") + + self._loop = asyncio.get_running_loop() + self._cancellation_futures = [] + logging.debug(f"Executor started, associated with event loop: {self._loop!r}") + + self._running = True + + def _to_config(self) -> DockerCommandLineCodeExecutorConfig: + """(Experimental) Convert the component to a config object.""" + if self._functions: + logging.info("Functions will not be included in serialized configuration") + + return DockerCommandLineCodeExecutorConfig( + image=self._image, + container_name=self.container_name, + timeout=self._timeout, + work_dir=str(self._work_dir) if self._work_dir else None, + bind_dir=str(self._bind_dir) if self._bind_dir else None, + auto_remove=self._auto_remove, + stop_container=self._stop_container, + functions_module=self._functions_module, + extra_volumes=self._extra_volumes, + extra_hosts=self._extra_hosts, + init_command=self._init_command, + delete_tmp_files=self._delete_tmp_files, + ) + + @classmethod + def _from_config(cls, config: DockerCommandLineCodeExecutorConfig) -> Self: + """(Experimental) Create a component from a config object.""" + + return cls( + image=config.image, + container_name=config.container_name, + timeout=config.timeout, + work_dir=Path(config.work_dir) if config.work_dir else None, + bind_dir=Path(config.bind_dir) if config.bind_dir else None, + auto_remove=config.auto_remove, + stop_container=config.stop_container, + functions=[], # Functions not restored from config + functions_module=config.functions_module, + extra_volumes=config.extra_volumes, + extra_hosts=config.extra_hosts, + init_command=config.init_command, + delete_tmp_files=config.delete_tmp_files, + ) diff --git a/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/__init__.py b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/__init__.py new file mode 100644 index 0000000..549c178 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/__init__.py @@ -0,0 +1,10 @@ +from ._docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterCodeResult +from ._jupyter_server import DockerJupyterServer, JupyterClient, JupyterKernelClient + +__all__ = [ + "DockerJupyterCodeExecutor", + "DockerJupyterServer", + "JupyterClient", + "JupyterKernelClient", + "DockerJupyterCodeResult", +] diff --git a/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_docker_jupyter.py b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_docker_jupyter.py new file mode 100644 index 0000000..33a5bbd --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_docker_jupyter.py @@ -0,0 +1,300 @@ +import asyncio +import base64 +import json +import os +import tempfile +import uuid +from dataclasses import dataclass +from pathlib import Path +from types import TracebackType +from typing import List, Optional, Union + +from agentdhal_core import CancellationToken, Component +from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult +from agentdhal_extensions.code_executors._common import silence_pip +from pydantic import BaseModel +from typing_extensions import Self + +from ._jupyter_server import JupyterClient, JupyterConnectable, JupyterConnectionInfo, JupyterKernelClient + + +@dataclass +class DockerJupyterCodeResult(CodeResult): + """(Experimental) A code result class for IPython code executor.""" + + output_files: list[Path] + + +class DockerJupyterCodeExecutorConfig(BaseModel): + """Configuration for JupyterCodeExecutor""" + + jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo] + kernel_name: str = "python3" + timeout: int = 60 + output_dir: Optional[Union[Path, str]] = None + + class Config: + arbitrary_types_allowed = True + + +class DockerJupyterCodeExecutor(CodeExecutor, Component[DockerJupyterCodeExecutorConfig]): + """(Experimental) A code executor class that executes code statefully using + a Jupyter server supplied to this class. + + Each execution is stateful and can access variables created from previous + executions in the same session. + + To use this, you need to install the following dependencies: + + .. code-block:: shell + + pip install "agentdhal-ext[docker-jupyter-executor]" + + Args: + jupyter_server (Union[JupyterConnectable, JupyterConnectionInfo]): The Jupyter server to use. + kernel_name (str): The kernel name to use. Make sure it is installed. + By default, it is "python3". + timeout (int): The timeout for code execution, by default 60. + output_dir (str): The directory to save output files, by default None. + + Example of using it directly: + + .. code-block:: python + + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_core.code_executor import CodeBlock + from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer + + + async def main() -> None: + async with DockerJupyterServer() as jupyter_server: + async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor: + code_blocks = [CodeBlock(code="print('hello world!')", language="python")] + code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken()) + print(code_result) + + + asyncio.run(main()) + + Example of using it with your own jupyter image: + + .. code-block:: python + + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_core.code_executor import CodeBlock + from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer + + + async def main() -> None: + async with DockerJupyterServer(custom_image_name="your_custom_images_name", expose_port=8888) as jupyter_server: + async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor: + code_blocks = [CodeBlock(code="print('hello world!')", language="python")] + code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken()) + print(code_result) + + + asyncio.run(main()) + + Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool + + + async def main() -> None: + async with DockerJupyterServer() as jupyter_server: + async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor: + tool = PythonCodeExecutionTool(executor) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent("assistant", model_client=model_client, tools=[tool]) + result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.") + print(result) + + + asyncio.run(main()) + + Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import CodeExecutorAgent + from agentdhal_agentchat.messages import TextMessage + from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer + from agentdhal_core import CancellationToken + + + async def main() -> None: + async with DockerJupyterServer() as jupyter_server: + async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor: + code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor) + task = TextMessage( + content='''Here is some code + ```python + print('Hello world') + ``` + ''', + source="user", + ) + response = await code_executor_agent.on_messages([task], CancellationToken()) + print(response.chat_message) + + + asyncio.run(main()) + + """ + + component_config_schema = DockerJupyterCodeExecutorConfig + component_provider_override = "agentdhal_extensions.code_executors.docker_jupyter.DockerJupyterCodeExecutor" + + def __init__( + self, + jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo], + kernel_name: str = "python3", + timeout: int = 60, + output_dir: Path | None = None, + ): + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + if isinstance(jupyter_server, JupyterConnectable): + self._connection_info = jupyter_server.connection_info + elif isinstance(jupyter_server, JupyterConnectionInfo): + self._connection_info = jupyter_server + else: + raise ValueError("jupyter_server must be a JupyterConnectable or JupyterConnectionInfo.") + + self._output_dir = output_dir or getattr(jupyter_server, "_bind_dir", None) + if not self._output_dir: + with tempfile.TemporaryDirectory() as temp_dir: + self._output_dir = Path(temp_dir) + self._output_dir.mkdir(exist_ok=True) + + self._jupyter_client = JupyterClient(self._connection_info) + + self._kernel_name = kernel_name + self._timeout = timeout + self._async_jupyter_kernel_client: Optional[JupyterKernelClient] = None + self._kernel_id: Optional[str] = None + + async def _ensure_async_kernel_client(self) -> JupyterKernelClient: + """Ensure that an async kernel client exists and return it.""" + if self._kernel_id is None: + await self.start() + assert self._kernel_id is not None + if self._async_jupyter_kernel_client is None: + self._async_jupyter_kernel_client = await self._jupyter_client.get_kernel_client(self._kernel_id) + return self._async_jupyter_kernel_client + + async def execute_code_blocks( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> DockerJupyterCodeResult: + """(Experimental) Execute a list of code blocks and return the result. + + This method executes a list of code blocks as cells in the Jupyter kernel. + See: https://jupyter-client.readthedocs.io/en/stable/messaging.html + for the message protocol. + + Args: + code_blocks (List[CodeBlock]): A list of code blocks to execute. + + Returns: + DockerJupyterCodeResult: The result of the code execution. + """ + kernel_client = await self._ensure_async_kernel_client() + # Wait for kernel to be ready using async client + is_ready = await kernel_client.wait_for_ready(timeout_seconds=self._timeout) + if not is_ready: + return DockerJupyterCodeResult(exit_code=1, output="ERROR: Kernel not ready", output_files=[]) + + outputs: List[str] = [] + output_files: List[Path] = [] + for code_block in code_blocks: + code = silence_pip(code_block.code, code_block.language) + # Execute code using async client + exec_task = asyncio.create_task(kernel_client.execute(code, timeout_seconds=self._timeout)) + cancellation_token.link_future(exec_task) + result = await exec_task + if result.is_ok: + outputs.append(result.output) + for data in result.data_items: + if data.mime_type == "image/png": + path = self._save_image(data.data) + outputs.append(path) + output_files.append(Path(path)) + elif data.mime_type == "text/html": + path = self._save_html(data.data) + outputs.append(path) + output_files.append(Path(path)) + else: + outputs.append(json.dumps(data.data)) + else: + existing_output = "\n".join([str(output) for output in outputs]) + return DockerJupyterCodeResult( + exit_code=1, output=existing_output + "\nERROR: " + result.output, output_files=output_files + ) + return DockerJupyterCodeResult( + exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files + ) + + async def restart(self) -> None: + """(Experimental) Restart a new session.""" + # Use async client to restart kernel + if self._kernel_id is not None: + await self._jupyter_client.restart_kernel(self._kernel_id) + # Reset the clients to force recreation + if self._async_jupyter_kernel_client is not None: + await self._async_jupyter_kernel_client.stop() + self._async_jupyter_kernel_client = None + + async def start(self) -> None: + """(Experimental) Start a new session.""" + available_kernels = await self._jupyter_client.list_kernel_specs() + if self._kernel_name not in available_kernels["kernelspecs"]: + raise ValueError(f"Kernel {self._kernel_name} is not installed.") + self._kernel_id = await self._jupyter_client.start_kernel(self._kernel_name) + + def _save_image(self, image_data_base64: str) -> str: + """Save image data to a file.""" + image_data = base64.b64decode(image_data_base64) + filename = f"{uuid.uuid4().hex}.png" + path = os.path.join(str(self._output_dir), filename) + with open(path, "wb") as f: + f.write(image_data) + return os.path.abspath(path) + + def _save_html(self, html_data: str) -> str: + """Save html data to a file.""" + filename = f"{uuid.uuid4().hex}.html" + path = os.path.join(str(self._output_dir), filename) + with open(path, "w") as f: + f.write(html_data) + return os.path.abspath(path) + + async def stop(self) -> None: + """Stop the kernel.""" + if self._kernel_id is not None: + await self._jupyter_client.delete_kernel(self._kernel_id) + if self._async_jupyter_kernel_client is not None: + await self._async_jupyter_kernel_client.stop() + self._async_jupyter_kernel_client = None + await self._jupyter_client.close() + + async def __aenter__(self) -> Self: + await self.start() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.stop() diff --git a/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_jupyter_server.py b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_jupyter_server.py new file mode 100644 index 0000000..0655e7b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/docker_jupyter/_jupyter_server.py @@ -0,0 +1,430 @@ +import asyncio +import atexit +import datetime +import io +import json +import logging +import os +import secrets +import uuid +from dataclasses import dataclass +from pathlib import Path +from time import sleep +from types import TracebackType +from typing import Any, Dict, List, Optional, Protocol, Type, Union, cast, runtime_checkable + +import aiohttp +import docker +import docker.errors +import requests +import websockets +from requests.adapters import HTTPAdapter, Retry +from typing_extensions import Self + + +@dataclass +class JupyterConnectionInfo: + """(Experimental)""" + + host: str + """`str` - Host of the Jupyter gateway server""" + use_https: bool + """`bool` - Whether to use HTTPS""" + port: Optional[int] = None + """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used""" + token: Optional[str] = None + """`Optional[str]` - Token for authentication. If None, no token is used""" + + +@runtime_checkable +class JupyterConnectable(Protocol): + """(Experimental)""" + + @property + def connection_info(self) -> JupyterConnectionInfo: + """Return the connection information for this connectable.""" + ... + + +class JupyterClient: + def __init__(self, connection_info: JupyterConnectionInfo): + """(Experimental) A client for communicating with a Jupyter gateway server. + + Args: + connection_info (JupyterConnectionInfo): Connection information + """ + self._connection_info = connection_info + self._session = requests.Session() + retries = Retry(total=5, backoff_factor=0.1) + self._session.mount("http://", HTTPAdapter(max_retries=retries)) + # Create aiohttp session for async requests + self._async_session: aiohttp.ClientSession | None = None + + async def _ensure_async_session(self) -> aiohttp.ClientSession: + if self._async_session is None: + self._async_session = aiohttp.ClientSession() + return self._async_session + + def _get_headers(self) -> Dict[str, str]: + if self._connection_info.token is None: + return {} + return {"Authorization": f"token {self._connection_info.token}"} + + def _get_api_base_url(self) -> str: + protocol = "https" if self._connection_info.use_https else "http" + port = f":{self._connection_info.port}" if self._connection_info.port else "" + return f"{protocol}://{self._connection_info.host}{port}" + + def _get_ws_base_url(self) -> str: + port = f":{self._connection_info.port}" if self._connection_info.port else "" + return f"ws://{self._connection_info.host}{port}" + + async def list_kernel_specs(self) -> Dict[str, Dict[str, str]]: + response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) + return cast(Dict[str, Dict[str, str]], response.json()) + + async def list_kernels(self) -> List[Dict[str, str]]: + response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) + return cast(List[Dict[str, str]], response.json()) + + async def start_kernel(self, kernel_spec_name: str) -> str: + """Start a new kernel asynchronously. + + Args: + kernel_spec_name (str): Name of the kernel spec to start + + Returns: + str: ID of the started kernel + """ + session = await self._ensure_async_session() + async with session.post( + f"{self._get_api_base_url()}/api/kernels", + headers=self._get_headers(), + json={"name": kernel_spec_name}, + ) as response: + data = await response.json() + return cast(str, data["id"]) + + async def delete_kernel(self, kernel_id: str) -> None: + session = await self._ensure_async_session() + async with session.delete( + f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers() + ) as response: + response.raise_for_status() + + async def restart_kernel(self, kernel_id: str) -> None: + session = await self._ensure_async_session() + async with session.post( + f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers() + ) as response: + response.raise_for_status() + + async def get_kernel_client(self, kernel_id: str) -> "JupyterKernelClient": + ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels" + # Using websockets library for async websocket connections + ws = await websockets.connect(ws_url, additional_headers=self._get_headers()) + return JupyterKernelClient(ws) + + async def close(self) -> None: + """Close the async session""" + if self._async_session is not None: + await self._async_session.close() + self._async_session = None + self._session.close() + + +@dataclass +class DataItem: + mime_type: str + data: str + + +@dataclass +class ExecutionResult: + is_ok: bool + output: str + data_items: List[DataItem] + + +class JupyterKernelClient: + """An asynchronous client for communicating with a Jupyter kernel.""" + + def __init__(self, websocket: websockets.ClientConnection) -> None: + self._session_id = uuid.uuid4().hex + self._websocket = websocket + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + await self.stop() + + async def stop(self) -> None: + await self._websocket.close() + + async def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str: + timestamp = datetime.datetime.now().isoformat() + message_id = uuid.uuid4().hex + message = { + "header": { + "username": "agentdhal", + "version": "5.0", + "session": self._session_id, + "msg_id": message_id, + "msg_type": message_type, + "date": timestamp, + }, + "parent_header": {}, + "channel": channel, + "content": content, + "metadata": {}, + "buffers": {}, + } + await self._websocket.send(json.dumps(message)) + return message_id + + async def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]: + try: + if timeout_seconds is not None: + data = await asyncio.wait_for(self._websocket.recv(), timeout=timeout_seconds) + else: + data = await self._websocket.recv() + if isinstance(data, bytes): + return cast(Dict[str, Any], json.loads(data.decode("utf-8"))) + return cast(Dict[str, Any], json.loads(data)) + except asyncio.TimeoutError: + return None + + async def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool: + message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request") + while True: + message = await self._receive_message(timeout_seconds) + # This means we timed out with no new messages. + if message is None: + return False + if ( + message.get("parent_header", {}).get("msg_id") == message_id + and message["msg_type"] == "kernel_info_reply" + ): + return True + + async def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult: + message_id = await self._send_message( + content={ + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + channel="shell", + message_type="execute_request", + ) + + text_output: List[str] = [] + data_output: List[DataItem] = [] + while True: + message = await self._receive_message(timeout_seconds) + if message is None: + return ExecutionResult( + is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[] + ) + + # Ignore messages that are not for this execution. + if message.get("parent_header", {}).get("msg_id") != message_id: + continue + + msg_type = message["msg_type"] + content = message["content"] + if msg_type in ["execute_result", "display_data"]: + for data_type, data in content["data"].items(): + if data_type == "text/plain": + text_output.append(data) + elif data_type.startswith("image/") or data_type == "text/html": + data_output.append(DataItem(mime_type=data_type, data=data)) + else: + text_output.append(json.dumps(data)) + elif msg_type == "stream": + text_output.append(content["text"]) + elif msg_type == "error": + # Output is an error. + return ExecutionResult( + is_ok=False, + output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", + data_items=[], + ) + if msg_type == "status" and content["execution_state"] == "idle": + break + return ExecutionResult( + is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output + ) + + +class DockerJupyterServer(JupyterConnectable): + DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation + + SHELL ["/bin/bash", "-o", "pipefail", "-c"] + + USER ${NB_UID} + RUN mamba install --yes jupyter_kernel_gateway ipykernel && \ + mamba clean --all -f -y && \ + fix-permissions "${CONDA_DIR}" && \ + fix-permissions "/home/${NB_USER}" + + ENV TOKEN="UNSET" + CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \ + --KernelGatewayApp.port=8888 \ + --KernelGatewayApp.auth_token="${TOKEN}" \ + --JupyterApp.answer_yes=true \ + --JupyterWebsocketPersonality.list_kernels=true + + EXPOSE 8888 + + WORKDIR "${HOME}" + """ + + class GenerateToken: + pass + + def __init__( + self, + *, + custom_image_name: Optional[str] = None, + container_name: Optional[str] = None, + auto_remove: bool = True, + stop_container: bool = True, + docker_env: Optional[Dict[str, str]] = None, + expose_port: int = 8888, + token: Optional[Union[str, GenerateToken]] = None, + work_dir: Union[Path, str] = "/workspace", + bind_dir: Optional[Union[Path, str]] = None, + ): + """Start a Jupyter kernel gateway server in a Docker container. + + Args: + custom_image_name: Custom Docker image to use. If None, builds and uses bundled image. + container_name: Name for the Docker container. Auto-generated if None. + auto_remove: If True, container will be deleted when stopped. + stop_container: If True, container stops on program exit or when context manager exits. + docker_env: Additional environment variables for the container. + expose_port: Port to expose for Jupyter connection. + token: Authentication token. If GenerateToken, creates random token. Empty for no auth. + work_dir: Working directory inside the container. + bind_dir: Local directory to bind to container's work_dir. + """ + # Generate container name if not provided + container_name = container_name or f"agentdhal-jupyterkernelgateway-{uuid.uuid4()}" + + # Initialize Docker client + client = docker.from_env() + # Set up bind directory if specified + self._bind_dir: Optional[Path] = None + if bind_dir: + self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir + self._bind_dir.mkdir(exist_ok=True) + os.chmod(bind_dir, 0o777) + + # Determine and prepare Docker image + image_name = custom_image_name or "agentdhal-jupyterkernelgateway" + if not custom_image_name: + try: + client.images.get(image_name) + except docker.errors.ImageNotFound: + # Build default image if not found + here = Path(__file__).parent + dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8")) + logging.info(f"Building image {image_name}...") + client.images.build(path=str(here), fileobj=dockerfile, tag=image_name) + logging.info(f"Image {image_name} built successfully") + else: + # Verify custom image exists + try: + client.images.get(image_name) + except docker.errors.ImageNotFound as err: + raise ValueError(f"Custom image {image_name} does not exist") from err + if docker_env is None: + docker_env = {} + if token is None: + token = DockerJupyterServer.GenerateToken() + # Set up authentication token + self._token = secrets.token_hex(32) if isinstance(token, DockerJupyterServer.GenerateToken) else token + + # Prepare environment variables + env = {"TOKEN": self._token} + env.update(docker_env) + + # Define volume configuration if bind directory is specified + volumes = {str(self._bind_dir): {"bind": str(work_dir), "mode": "rw"}} if self._bind_dir else None + + # Start the container + container = client.containers.run( + image_name, + detach=True, + auto_remove=auto_remove, + environment=env, + publish_all_ports=True, + name=container_name, + volumes=volumes, + working_dir=str(work_dir), + ) + + # Wait for container to be ready + self._wait_for_ready(container) + + # Store container information + self._container = container + self._port = int(container.ports[f"{expose_port}/tcp"][0]["HostPort"]) + self._container_id = container.id + self._expose_port = expose_port + + if self._container_id is None: + raise ValueError("Failed to obtain container id.") + + # Define cleanup function + def cleanup() -> None: + try: + assert self._container_id is not None + inner_container = client.containers.get(self._container_id) + inner_container.stop() + except docker.errors.NotFound: + pass + atexit.unregister(cleanup) + + # Register cleanup if container should be stopped automatically + if stop_container: + atexit.register(cleanup) + + self._cleanup_func = cleanup + self._stop_container = stop_container + + @property + def connection_info(self) -> JupyterConnectionInfo: + return JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=self._port, token=self._token) + + def _wait_for_ready(self, container: Any, timeout: int = 60, stop_time: float = 0.1) -> None: + elapsed_time = 0.0 + while container.status != "running" and elapsed_time < timeout: + sleep(stop_time) + elapsed_time += stop_time + container.reload() + continue + if container.status != "running": + raise ValueError("Container failed to start") + + async def stop(self) -> None: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._cleanup_func) + + async def get_client(self) -> JupyterClient: + return JupyterClient(self.connection_info) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + await self.stop() diff --git a/agent_dhal/agentdhal_extensions/code_executors/jupyter/__init__.py b/agent_dhal/agentdhal_extensions/code_executors/jupyter/__init__.py new file mode 100644 index 0000000..1a6ba79 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/jupyter/__init__.py @@ -0,0 +1,6 @@ +from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult + +__all__ = [ + "JupyterCodeExecutor", + "JupyterCodeResult", +] diff --git a/agent_dhal/agentdhal_extensions/code_executors/jupyter/_jupyter_code_executor.py b/agent_dhal/agentdhal_extensions/code_executors/jupyter/_jupyter_code_executor.py new file mode 100644 index 0000000..c03688d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/jupyter/_jupyter_code_executor.py @@ -0,0 +1,335 @@ +import asyncio +import base64 +import json +import re +import sys +import tempfile +import uuid +import warnings +from dataclasses import dataclass +from pathlib import Path + +from agentdhal_core import Component +from pydantic import BaseModel + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from contextlib import AbstractAsyncContextManager +from typing import Optional, Union + +from agentdhal_core import CancellationToken +from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult +from nbclient import NotebookClient +from nbformat import NotebookNode +from nbformat import v4 as nbformat +from typing_extensions import Self + +from .._common import silence_pip + + +@dataclass +class JupyterCodeResult(CodeResult): + """A code result class for Jupyter code executor.""" + + output_files: list[Path] + + +class JupyterCodeExecutorConfig(BaseModel): + """Configuration for JupyterCodeExecutor""" + + kernel_name: str = "python3" + timeout: int = 60 + output_dir: Optional[str] = None + + +class JupyterCodeExecutor(CodeExecutor, Component[JupyterCodeExecutorConfig]): + """A code executor class that executes code statefully using [nbclient](https://github.com/jupyter/nbclient). + + .. danger:: + + This will execute code on the local machine. If being used with LLM generated code, caution should be used. + + Example of using it directly: + + .. code-block:: python + + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_core.code_executor import CodeBlock + from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor + + + async def main() -> None: + async with JupyterCodeExecutor() as executor: + cancel_token = CancellationToken() + code_blocks = [CodeBlock(code="print('hello world!')", language="python")] + code_result = await executor.execute_code_blocks(code_blocks, cancel_token) + print(code_result) + + + asyncio.run(main()) + + Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool + + + async def main() -> None: + async with JupyterCodeExecutor() as executor: + tool = PythonCodeExecutionTool(executor) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent("assistant", model_client=model_client, tools=[tool]) + result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.") + print(result) + + + asyncio.run(main()) + + Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import CodeExecutorAgent + from agentdhal_agentchat.messages import TextMessage + from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor + from agentdhal_core import CancellationToken + + + async def main() -> None: + async with JupyterCodeExecutor() as executor: + code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor) + task = TextMessage( + content='''Here is some code + ```python + print('Hello world') + ``` + ''', + source="user", + ) + response = await code_executor_agent.on_messages([task], CancellationToken()) + print(response.chat_message) + + + asyncio.run(main()) + + + Args: + kernel_name (str): The kernel name to use. By default, "python3". + timeout (int): The timeout for code execution, by default 60. + output_dir (Path): The directory to save output files, by default a temporary directory. + + + .. note:: + Using the current directory (".") as output directory is deprecated. Using it will raise a deprecation warning. + """ + + component_config_schema = JupyterCodeExecutorConfig + component_provider_override = "agentdhal_extensions.code_executors.jupyter.JupyterCodeExecutor" + + def __init__( + self, + kernel_name: str = "python3", + timeout: int = 60, + output_dir: Optional[Union[Path, str]] = None, + ): + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + self._output_dir: Path = Path(tempfile.mkdtemp()) if output_dir is None else Path(output_dir) + self._output_dir.mkdir(exist_ok=True, parents=True) + + self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None + self._temp_dir_path: Optional[Path] = None + + self._started = False + + self._kernel_name = kernel_name + self._timeout = timeout + + self._client: Optional[NotebookClient] = None + self.kernel_context: Optional[AbstractAsyncContextManager[None]] = None + + async def execute_code_blocks( + self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken + ) -> JupyterCodeResult: + """Execute code blocks and return the result. + + Args: + code_blocks (list[CodeBlock]): The code blocks to execute. + + Returns: + JupyterCodeResult: The result of the code execution. + """ + outputs: list[str] = [] + output_files: list[Path] = [] + exit_code = 0 + + for code_block in code_blocks: + result = await self._execute_code_block(code_block, cancellation_token) + exit_code = result.exit_code + outputs.append(result.output) + output_files.extend(result.output_files) + + # Stop execution if one code block fails + if exit_code != 0: + break + + return JupyterCodeResult(exit_code=exit_code, output="\n".join(outputs), output_files=output_files) + + async def _execute_code_block( + self, code_block: CodeBlock, cancellation_token: CancellationToken + ) -> JupyterCodeResult: + """Execute single code block and return the result. + + Args: + code_block (CodeBlock): The code block to execute. + + Returns: + JupyterCodeResult: The result of the code execution. + """ + execute_task = asyncio.create_task( + self._execute_cell( + nbformat.new_code_cell(silence_pip(code_block.code, code_block.language)) # type: ignore + ) + ) + + cancellation_token.link_future(execute_task) + output_cell = await asyncio.wait_for(asyncio.shield(execute_task), timeout=self._timeout) + + outputs: list[str] = [] + output_files: list[Path] = [] + exit_code = 0 + + for output in output_cell.get("outputs", []): + match output.get("output_type"): + case "stream": + outputs.append(output.get("text", "")) + case "error": + traceback = re.sub(r"\x1b\[[0-9;]*[A-Za-z]", "", "\n".join(output["traceback"])) + outputs.append(traceback) + exit_code = 1 + case "execute_result" | "display_data": + data = output.get("data", {}) + for mime, content in data.items(): + match mime: + case "text/plain": + outputs.append(content) + case "image/png": + path = self._save_image(content) + output_files.append(path) + case "image/jpeg": + # TODO: Should this also be encoded? Images are encoded as both png and jpg + pass + case "text/html": + path = self._save_html(content) + output_files.append(path) + case _: + outputs.append(json.dumps(content)) + case _: + pass + + return JupyterCodeResult(exit_code=exit_code, output="\n".join(outputs), output_files=output_files) + + async def _execute_cell(self, cell: NotebookNode) -> NotebookNode: + # Temporary push cell to nb as async_execute_cell expects it. But then we want to remove it again as cells can take up significant amount of memory (especially with images) + if not self._client: + raise RuntimeError("Executor must be started before executing cells") + self._client.nb.cells.append(cell) + output = await self._client.async_execute_cell( + cell, + cell_index=0, + ) + self._client.nb.cells.pop() + return output + + def _save_image(self, image_data_base64: str) -> Path: + """Save image data to a file.""" + image_data = base64.b64decode(image_data_base64) + path = self._output_dir / f"{uuid.uuid4().hex}.png" + path.write_bytes(image_data) + return path.absolute() + + def _save_html(self, html_data: str) -> Path: + """Save HTML data to a file.""" + path = self._output_dir / f"{uuid.uuid4().hex}.html" + path.write_text(html_data) + return path.absolute() + + async def restart(self) -> None: + """Restart the code executor.""" + await self.stop() + await self.start() + + async def start(self) -> None: + """(Experimental) Start the code executor. + + Initializes the Jupyter Notebook execution environment by creating a new notebook and setting it up with the specified Jupyter Kernel. + Marks the executor as started, allowing for code execution. + This method should be called before executing any code blocks. + """ + if self._started: + return + + notebook: NotebookNode = nbformat.new_notebook() # type: ignore + + self._client = NotebookClient( + nb=notebook, + kernel_name=self._kernel_name, + timeout=self._timeout, + allow_errors=True, + ) + + self.kernel_context = self._client.async_setup_kernel() + await self.kernel_context.__aenter__() + + self._started = True + + async def stop(self) -> None: + """(Experimental) Stop the code executor. + + Terminates the Jupyter Notebook execution by exiting the kernel context and cleaning up the associated resources.""" + if not self._started: + return + + if self.kernel_context is not None: + await self.kernel_context.__aexit__(None, None, None) + self.kernel_context = None + + self._client = None + self._started = False + + def _to_config(self) -> JupyterCodeExecutorConfig: + """Convert current instance to config object""" + return JupyterCodeExecutorConfig( + kernel_name=self._kernel_name, timeout=self._timeout, output_dir=str(self.output_dir) + ) + + @property + def output_dir(self) -> Path: + # If a user specifies the current directory, warn them that this is deprecated + if self._output_dir == Path("."): + warnings.warn( + "Using the current directory as output_dir is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return self._output_dir + + @classmethod + def _from_config(cls, config: JupyterCodeExecutorConfig) -> Self: + """Create instance from config object""" + return cls( + kernel_name=config.kernel_name, + timeout=config.timeout, + output_dir=Path(config.output_dir) if config.output_dir else None, + ) diff --git a/agent_dhal/agentdhal_extensions/code_executors/local/__init__.py b/agent_dhal/agentdhal_extensions/code_executors/local/__init__.py new file mode 100644 index 0000000..acf059c --- /dev/null +++ b/agent_dhal/agentdhal_extensions/code_executors/local/__init__.py @@ -0,0 +1,517 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py +# Credit to original authors + +import asyncio +import logging +import os +import sys +import tempfile +import warnings +from hashlib import sha256 +from pathlib import Path +from string import Template +from types import SimpleNamespace +from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union + +from agentdhal_core import CancellationToken, Component +from agentdhal_core.code_executor import CodeBlock, CodeExecutor, FunctionWithRequirements, FunctionWithRequirementsStr +from pydantic import BaseModel +from typing_extensions import ParamSpec, Self + +from .._common import ( + PYTHON_VARIANTS, + CommandLineCodeResult, + build_python_functions_file, + get_file_name_from_content, + lang_to_cmd, + silence_pip, + to_stub, +) + +__all__ = ("LocalCommandLineCodeExecutor",) + +A = ParamSpec("A") + + +class LocalCommandLineCodeExecutorConfig(BaseModel): + """Configuration for LocalCommandLineCodeExecutor""" + + timeout: int = 60 + work_dir: Optional[str] = None + functions_module: str = "functions" + cleanup_temp_files: bool = True + + +class LocalCommandLineCodeExecutor(CodeExecutor, Component[LocalCommandLineCodeExecutorConfig]): + """A code executor class that executes code through a local command line + environment. + + .. danger:: + + This will execute code on the local machine. If being used with LLM generated code, caution should be used. + + Each code block is saved as a file and executed in a separate process in + the working directory, and a unique file is generated and saved in the + working directory for each code block. + The code blocks are executed in the order they are received. + Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive + commands from being executed which may potentially affect the users environment. + Currently the only supported languages is Python and shell scripts. + For Python code, use the language "python" for the code block. + For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code + block. + + .. note:: + + On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses. + + .. code-block:: python + + import sys + import asyncio + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + Args: + timeout (int): The timeout for the execution of any single code block. Default is 60. + work_dir (str): The working directory for the code execution. If None, + a default working directory will be used. The default working directory is a temporary directory. + functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. + functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions". + cleanup_temp_files (bool, optional): Whether to automatically clean up temporary files after execution. Defaults to True. + virtual_env_context (Optional[SimpleNamespace], optional): The virtual environment context. Defaults to None. + + .. note:: + Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning. + + + Example: + + How to use `LocalCommandLineCodeExecutor` with a virtual environment different from the one used to run the autogen application: + Set up a virtual environment using the `venv` module, and pass its context to the initializer of `LocalCommandLineCodeExecutor`. This way, the executor will run code within the new environment. + + .. code-block:: python + + import venv + from pathlib import Path + import asyncio + + from agentdhal_core import CancellationToken + from agentdhal_core.code_executor import CodeBlock + from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor + + + async def example(): + work_dir = Path("coding") + work_dir.mkdir(exist_ok=True) + + venv_dir = work_dir / ".venv" + venv_builder = venv.EnvBuilder(with_pip=True) + venv_builder.create(venv_dir) + venv_context = venv_builder.ensure_directories(venv_dir) + + local_executor = LocalCommandLineCodeExecutor(work_dir=work_dir, virtual_env_context=venv_context) + await local_executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="bash", code="pip install matplotlib"), + ], + cancellation_token=CancellationToken(), + ) + + + asyncio.run(example()) + + """ + + component_config_schema = LocalCommandLineCodeExecutorConfig + component_provider_override = "agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor" + + SUPPORTED_LANGUAGES: ClassVar[List[str]] = [ + "bash", + "shell", + "sh", + "pwsh", + "powershell", + "ps1", + "python", + ] + FUNCTION_PROMPT_TEMPLATE: ClassVar[ + str + ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names. + +For example, if there was a function called `foo` you could import it by writing `from $module_name import foo` + +$functions""" + + def __init__( + self, + timeout: int = 60, + work_dir: Optional[Union[Path, str]] = None, + functions: Sequence[ + Union[ + FunctionWithRequirements[Any, A], + Callable[..., Any], + FunctionWithRequirementsStr, + ] + ] = [], + functions_module: str = "functions", + cleanup_temp_files: bool = True, + virtual_env_context: Optional[SimpleNamespace] = None, + ): + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + self._timeout = timeout + + self._work_dir: Optional[Path] = None + if work_dir is not None: + # Check if user provided work_dir is the current directory and warn if so. + if Path(work_dir).resolve() == Path.cwd().resolve(): + warnings.warn( + "Using the current directory as work_dir is deprecated.", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(work_dir, str): + self._work_dir = Path(work_dir) + else: + self._work_dir = work_dir + self._work_dir.mkdir(exist_ok=True) + + self._functions = functions + # Setup could take some time so we intentionally wait for the first code block to do it. + if len(functions) > 0: + self._setup_functions_complete = False + else: + self._setup_functions_complete = True + + if not functions_module.isidentifier(): + raise ValueError("Module name must be a valid Python identifier") + self._functions_module = functions_module + + self._cleanup_temp_files = cleanup_temp_files + self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context + + self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None + self._started = False + + # Check the current event loop policy if on windows. + if sys.platform == "win32": + current_policy = asyncio.get_event_loop_policy() + if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance( + current_policy, asyncio.WindowsProactorEventLoopPolicy + ): + warnings.warn( + "The current event loop policy is not WindowsProactorEventLoopPolicy. " + "This may cause issues with subprocesses. " + "Try setting the event loop policy to WindowsProactorEventLoopPolicy. " + "For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. " + "See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.", + stacklevel=2, + ) + + def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str: + """(Experimental) Format the functions for a prompt. + + The template includes two variables: + - `$module_name`: The module name. + - `$functions`: The functions formatted as stubs with two newlines between each function. + + Args: + prompt_template (str): The prompt template. Default is the class default. + + Returns: + str: The formatted prompt. + """ + + template = Template(prompt_template) + return template.substitute( + module_name=self._functions_module, + functions="\n\n".join([to_stub(func) for func in self._functions]), + ) + + @property + def timeout(self) -> int: + """(Experimental) The timeout for code execution.""" + return self._timeout + + @property + def work_dir(self) -> Path: + """(Experimental) The working directory for the code execution.""" + if self._work_dir is not None: + return self._work_dir + else: + # Automatically create temp directory if not exists + if self._temp_dir is None: + self._temp_dir = tempfile.TemporaryDirectory() + self._started = True + return Path(self._temp_dir.name) + + @property + def functions(self) -> List[str]: + raise NotImplementedError + + @property + def functions_module(self) -> str: + """(Experimental) The module name for the functions.""" + return self._functions_module + + @property + def cleanup_temp_files(self) -> bool: + """(Experimental) Whether to automatically clean up temporary files after execution.""" + return self._cleanup_temp_files + + async def _setup_functions(self, cancellation_token: CancellationToken) -> None: + func_file_content = build_python_functions_file(self._functions) + func_file = self.work_dir / f"{self._functions_module}.py" + func_file.write_text(func_file_content) + + # Collect requirements + lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] + flattened_packages = [item for sublist in lists_of_packages for item in sublist] + required_packages = list(set(flattened_packages)) + if len(required_packages) > 0: + logging.info("Ensuring packages are installed in executor.") + + cmd_args = ["-m", "pip", "install"] + cmd_args.extend(required_packages) + + if self._virtual_env_context: + py_executable = self._virtual_env_context.env_exe + else: + py_executable = sys.executable + + task = asyncio.create_task( + asyncio.create_subprocess_exec( + py_executable, + *cmd_args, + cwd=self.work_dir, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + ) + cancellation_token.link_future(task) + try: + proc = await task + stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout) + except asyncio.TimeoutError as e: + raise ValueError("Pip install timed out") from e + except asyncio.CancelledError as e: + raise ValueError("Pip install was cancelled") from e + + if proc.returncode is not None and proc.returncode != 0: + raise ValueError(f"Pip install failed. {stdout.decode()}, {stderr.decode()}") + + # Attempt to load the function file to check for syntax errors, imports etc. + exec_result = await self._execute_code_dont_check_setup( + [CodeBlock(code=func_file_content, language="python")], cancellation_token + ) + + if exec_result.exit_code != 0: + raise ValueError(f"Functions failed to load: {exec_result.output}") + + self._setup_functions_complete = True + + async def execute_code_blocks( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CommandLineCodeResult: + """(Experimental) Execute the code blocks and return the result. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + cancellation_token (CancellationToken): a token to cancel the operation + + Returns: + CommandLineCodeResult: The result of the code execution.""" + + if not self._setup_functions_complete: + await self._setup_functions(cancellation_token) + + return await self._execute_code_dont_check_setup(code_blocks, cancellation_token) + + async def _execute_code_dont_check_setup( + self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken + ) -> CommandLineCodeResult: + """ + Execute the provided code blocks in the local command line without re-checking setup. + Returns a CommandLineCodeResult indicating success or failure. + """ + logs_all: str = "" + file_names: List[Path] = [] + exitcode = 0 + + for code_block in code_blocks: + lang, code = code_block.language, code_block.code + lang = lang.lower() + + # Remove pip output where possible + code = silence_pip(code, lang) + + # Normalize python variants to "python" + if lang in PYTHON_VARIANTS: + lang = "python" + + # Abort if not supported + if lang not in self.SUPPORTED_LANGUAGES: + exitcode = 1 + logs_all += "\n" + f"unknown language {lang}" + break + + # Try extracting a filename (if present) + try: + filename = get_file_name_from_content(code, self.work_dir) + except ValueError: + return CommandLineCodeResult( + exit_code=1, + output="Filename is not in the workspace", + code_file=None, + ) + + # If no filename is found, create one + if filename is None: + code_hash = sha256(code.encode()).hexdigest() + if lang.startswith("python"): + ext = "py" + elif lang in ["pwsh", "powershell", "ps1"]: + ext = "ps1" + else: + ext = lang + + filename = f"tmp_code_{code_hash}.{ext}" + + written_file = (self.work_dir / filename).resolve() + with written_file.open("w", encoding="utf-8") as f: + f.write(code) + file_names.append(written_file) + + # Build environment + env = os.environ.copy() + if self._virtual_env_context: + virtual_env_bin_abs_path = os.path.abspath(self._virtual_env_context.bin_path) + env["PATH"] = f"{virtual_env_bin_abs_path}{os.pathsep}{env['PATH']}" + + # Decide how to invoke the script + if lang == "python": + program = ( + os.path.abspath(self._virtual_env_context.env_exe) if self._virtual_env_context else sys.executable + ) + extra_args = [str(written_file.absolute())] + else: + # Get the appropriate command for the language + program = lang_to_cmd(lang) + + # Special handling for PowerShell + if program == "pwsh": + extra_args = [ + "-NoProfile", + "-ExecutionPolicy", + "Bypass", + "-File", + str(written_file.absolute()), + ] + else: + # Shell commands (bash, sh, etc.) + extra_args = [str(written_file.absolute())] + + # Create a subprocess and run + task = asyncio.create_task( + asyncio.create_subprocess_exec( + program, + *extra_args, + cwd=self.work_dir, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + ) + cancellation_token.link_future(task) + + proc = None # Track the process + try: + proc = await task + stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout) + exitcode = proc.returncode or 0 + except asyncio.TimeoutError: + logs_all += "\nTimeout" + exitcode = 124 + if proc: + proc.terminate() + await proc.wait() # Ensure process is fully dead + break + except asyncio.CancelledError: + logs_all += "\nCancelled" + exitcode = 125 + if proc: + proc.terminate() + await proc.wait() + break + + logs_all += stderr.decode() + logs_all += stdout.decode() + + if exitcode != 0: + break + + code_file = str(file_names[0]) if file_names else None + code_result = CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file) + + if self._cleanup_temp_files: + for file in file_names: + try: + file.unlink(missing_ok=True) + except OSError as error: + logging.error(f"Failed to delete temporary file {file}: {error}") + + return code_result + + async def restart(self) -> None: + """(Experimental) Restart the code executor.""" + warnings.warn( + "Restarting local command line code executor is not supported. No action is taken.", + stacklevel=2, + ) + + async def start(self) -> None: + """(Experimental) Start the code executor. + + Initializes the local code executor and should be called before executing any code blocks. + It marks the executor internal state as started. + If no working directory is provided, the method creates a temporary directory for the executor to use. + """ + if self._work_dir is None and self._temp_dir is None: + self._temp_dir = tempfile.TemporaryDirectory() + self._started = True + + async def stop(self) -> None: + """(Experimental) Stop the code executor. + + Stops the local code executor and performs the cleanup of the temporary working directory (if it was created). + The executor's internal state is markes as no longer started. + """ + if self._temp_dir is not None: + self._temp_dir.cleanup() + self._temp_dir = None + self._started = False + pass + + def _to_config(self) -> LocalCommandLineCodeExecutorConfig: + if self._functions: + logging.info("Functions will not be included in serialized configuration") + if self._virtual_env_context: + logging.info("Virtual environment context will not be included in serialized configuration") + + return LocalCommandLineCodeExecutorConfig( + timeout=self._timeout, + work_dir=str(self.work_dir), + functions_module=self._functions_module, + cleanup_temp_files=self._cleanup_temp_files, + ) + + @classmethod + def _from_config(cls, config: LocalCommandLineCodeExecutorConfig) -> Self: + return cls( + timeout=config.timeout, + work_dir=Path(config.work_dir) if config.work_dir is not None else None, + functions_module=config.functions_module, + cleanup_temp_files=config.cleanup_temp_files, + ) diff --git a/agent_dhal/agentdhal_extensions/experimental/__init__.py b/agent_dhal/agentdhal_extensions/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/README.md b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/README.md new file mode 100644 index 0000000..d483054 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/README.md @@ -0,0 +1,210 @@ +# Task-Centric Memory +_(EXPERIMENTAL, RESEARCH IN PROGRESS)_ + +**Task-Centric Memory** is an active research project aimed at giving AI agents the ability to: + +* Accomplish general tasks more effectively by learning quickly and continually beyond context-window limitations. +* Remember guidance, corrections, plans, and demonstrations provided by users. +* Learn through the agent's own experience and adapt quickly to changing circumstances. +* Avoid repeating mistakes on tasks that are similar to those previously encountered. + +## Installation + +Install AutoGen and its extension package as follows: + +```bash +pip install -U "autogen-agentchat" "autogen-ext[openai]" "autogen-ext[task-centric-memory]" +``` + +## Quickstart + +

+ Description +

+ +This first code snippet runs a basic test to verify that the installation was successful, +as illustrated by the diagram to the right. + +```python +import asyncio +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.experimental.task_centric_memory import MemoryController +from autogen_ext.experimental.task_centric_memory.utils import PageLogger + + +async def main() -> None: + client = OpenAIChatCompletionClient(model="gpt-4o") + logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful. + memory_controller = MemoryController(reset=True, client=client, logger=logger) + + # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task. + await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color") + await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan") + await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite") + + # Retrieve memories for a new task that's related to only two of the stored memories. + memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?") + print("{} memories retrieved".format(len(memos))) + for memo in memos: + print("- " + memo.insight) + + +asyncio.run(main()) +``` + +

+ Description +

+ +This second code example shows one way to incorporate task-centric memory directly into an AutoGen agent, +in this case a subclass of RoutedAgent. +To keep the code short, only the simplest form of memory retrieval is exercised by this agent. + +```python + +import asyncio +from dataclasses import dataclass +from typing import List + +from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler +from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.experimental.task_centric_memory import MemoryController +from autogen_ext.experimental.task_centric_memory.utils import PageLogger + + +@dataclass +class Message: + content: str + + +class MemoryEnabledAgent(RoutedAgent): + def __init__( + self, description: str, model_client: ChatCompletionClient, memory_controller: MemoryController + ) -> None: + super().__init__(description) + self._model_client = model_client + self._memory_controller = memory_controller + + @message_handler + async def handle_message(self, message: Message, context: MessageContext) -> Message: + # Retrieve relevant memories for the task. + memos = await self._memory_controller.retrieve_relevant_memos(task=message.content) + + # Format the memories for the model. + formatted_memos = "Info that may be useful:\n" + "\n".join(["- " + memo.insight for memo in memos]) + print(f"{'-' * 23}Text appended to the user message{'-' * 24}\n{formatted_memos}\n{'-' * 80}") + + # Create the messages for the model with the retrieved memories. + messages: List[LLMMessage] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content=message.content, source="user"), + UserMessage(content=formatted_memos, source="user"), + ] + + # Call the model with the messages. + model_result = await self._model_client.create(messages=messages) + assert isinstance(model_result.content, str) + + # Send the model's response to the user. + return Message(content=model_result.content) + + +async def main() -> None: + client = OpenAIChatCompletionClient(model="gpt-4o") + logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart2"}) # Optional, but very useful. + memory_controller = MemoryController(reset=True, client=client, logger=logger) + + # Prepopulate memory to mimic learning from a prior session. + await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color") + await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan") + await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite") + + # Create and start an agent runtime. + runtime = SingleThreadedAgentRuntime() + runtime.start() + + # Register the agent type. + await MemoryEnabledAgent.register( + runtime, + "memory_enabled_agent", + lambda: MemoryEnabledAgent( + "A agent with memory", model_client=client, memory_controller=memory_controller + ), + ) + + # Send a direct message to the agent. + request = "What colors do I like most?" + print("User request: " + request) + response = await runtime.send_message( + Message(content=request), AgentId("memory_enabled_agent", "default") + ) + print("Agent response: " + response.content) + + # Stop the agent runtime. + await runtime.stop() + + +asyncio.run(main()) +``` + +## Sample Code + +The example above modifies the agent's code. +But it's also possible to add task-centric memory to an agent or multi-agent team _without_ modifying any agent code. +See the [sample code](../../../../../../samples/task_centric_memory) for that and other forms of fast, memory-based learning. + + +## Architecture + +

+ Description +

+ +The block diagram to the right outlines the key components of the architecture in the most general form. +The memory components are shown in blue, and the green blocks represent external components. + +The **Memory Controller** implements the fast-learning methods described below, +and manages communication with a **Memory Bank** containing a vector DB and associated structures. + +The **Agent or Team** is the AI agent or team of agents to which memory is being added. +The sample code shows how to add task-centric memory to a simple AssistantAgent or a MagenticOneGroupChat team. + +The **Apprentice, app, or service** represents the code that instantiates the agent and memory controller, +and routes information between them, effectively wrapping agent and memory into a combined component. +The term _Apprentice_ connotes that this combination uses memory to learn quickly on the job. +The Apprentice class is a minimal reference implementation provided as utility code for illustration and testing, +but most applications will use their own code instead of the Apprentice. + +## Memory Creation and Storage + +Each stored memory (called a _memo_) contains a text insight and (optionally) a task description. +The insight is intended to help the agent accomplish future tasks that are similar to a prior task. +The memory controller provides methods for different types of learning. +If the user provides advice for solving a given task, the advice is extracted by the model client and stored as an insight. +If the user demonstrates how to perform a task, +the task and demonstration are stored together as an insight used to solve similar but different tasks. +If the agent is given a task (free of side-effects) and some means of determining success or failure, +the memory controller repeats the following learning loop in the background some number of times: + +1. Test the agent on the task a few times to check for a failure. +2. If a failure is found, analyze the agent's response in order to: + 1. Diagnose the failure of reasoning or missing information, + 2. Phrase a general piece of advice, such as what a teacher might give to a student, + 3. Temporarily append this advice to the task description, + 4. Return to step 1. + 5. If some piece of advice succeeds in helping the agent solve the task a number of times, add the advice as an insight to memory. +3. For each insight to be stored in memory, an LLM is prompted to generate a set of free-form, multi-word topics related to the insight. Each topic is embedded to a fixed-length vector and stored in a vector DB mapping it to the topic’s related insight. + +## Memory Retrieval and Usage + +The memory controller provides methods for different types of memory retrieval. +When the agent is given a task, the following steps are performed by the controller: +1. The task is rephrased into a generalized form. +2. A set of free-form, multi-word query topics are generated from the generalized task. +3. A potentially large number of previously stored topics, those most similar to each query topic, are retrieved from the vector DB along with the insights they map to. +4. These candidate memos are filtered by the aggregate similarity of their stored topics to the query topics. +5. In the final filtering stage, an LLM is prompted to validate only those insights that seem potentially useful in solving the task at hand. + +Retrieved insights that pass the filtering steps are listed under a heading like +"Important insights that may help solve tasks like this", then appended to the task description before it is passed to the agent as usual. diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/__init__.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/__init__.py new file mode 100644 index 0000000..97415af --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/__init__.py @@ -0,0 +1,4 @@ +from ._memory_bank import MemoryBankConfig +from .memory_controller import MemoryController, MemoryControllerConfig + +__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"] diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_memory_bank.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_memory_bank.py new file mode 100644 index 0000000..62ba71b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_memory_bank.py @@ -0,0 +1,201 @@ +import os +import pickle +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, TypedDict + +from ._string_similarity_map import StringSimilarityMap +from .utils.page_logger import PageLogger + + +@dataclass +class Memo: + """ + Represents an atomic unit of memory that can be stored in a memory bank and later retrieved. + """ + + task: str | None # The task description, if any. + insight: str # A hint, solution, plan, or any other text that may help solve a similar task. + + +# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating +# the settings that change frequently, as when loading many settings from a single YAML file. +class MemoryBankConfig(TypedDict, total=False): + path: str + relevance_conversion_threshold: float + n_results: int + distance_threshold: int + + +class MemoryBank: + """ + Stores task-completion insights as memories in a vector DB for later retrieval. + + Args: + reset: True to clear the DB before starting. + config: An optional dict that can be used to override the following values: + + - path: The path to the directory where the memory bank files are stored. + - relevance_conversion_threshold: The threshold used to normalize relevance. + - n_results: The maximum number of most relevant results to return for any given topic. + - distance_threshold: The maximum string-pair distance for a memo to be retrieved. + + logger: An optional logger. If None, no logging will be performed. + """ + + def __init__( + self, + reset: bool, + config: MemoryBankConfig | None = None, + logger: PageLogger | None = None, + ) -> None: + if logger is None: + logger = PageLogger() # Nothing will be logged by this object. + self.logger = logger + self.logger.enter_function() + + # Apply default settings and any config overrides. + memory_dir_path = "./memory_bank/default" + self.relevance_conversion_threshold = 1.7 + self.n_results = 25 + self.distance_threshold = 100 + if config is not None: + memory_dir_path = config.get("path", memory_dir_path) + self.relevance_conversion_threshold = config.get( + "relevance_conversion_threshold", self.relevance_conversion_threshold + ) + self.n_results = config.get("n_results", self.n_results) + self.distance_threshold = config.get("distance_threshold", self.distance_threshold) + + memory_dir_path = os.path.expanduser(memory_dir_path) + self.logger.info("\nMEMORY BANK DIRECTORY {}".format(memory_dir_path)) + path_to_db_dir = os.path.join(memory_dir_path, "string_map") + self.path_to_dict = os.path.join(memory_dir_path, "uid_memo_dict.pkl") + + self.string_map = StringSimilarityMap(reset=reset, path_to_db_dir=path_to_db_dir, logger=self.logger) + + # Load or create the associated memo dict on disk. + self.uid_memo_dict: Dict[str, Memo] = {} + self.last_memo_id = 0 + if (not reset) and os.path.exists(self.path_to_dict): + self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict)) + with open(self.path_to_dict, "rb") as f: + self.uid_memo_dict = pickle.load(f) + self.last_memo_id = len(self.uid_memo_dict) + self.logger.info("\n{} MEMOS LOADED".format(len(self.uid_memo_dict))) + + # Clear the DB if requested. + if reset: + self._reset_memos() + + self.logger.leave_function() + + def reset(self) -> None: + """ + Forces immediate deletion of all contents, in memory and on disk. + """ + self.string_map.reset_db() + self._reset_memos() + + def _reset_memos(self) -> None: + """ + Forces immediate deletion of the memos, in memory and on disk. + """ + self.logger.info("\nCLEARING MEMOS") + self.uid_memo_dict = {} + self.save_memos() + + def save_memos(self) -> None: + """ + Saves the current memo structures (possibly empty) to disk. + """ + self.string_map.save_string_pairs() + with open(self.path_to_dict, "wb") as file: + self.logger.info("\nSAVING MEMOS TO DISK at {}".format(self.path_to_dict)) + pickle.dump(self.uid_memo_dict, file) + + def contains_memos(self) -> bool: + """ + Returns True if the memory bank contains any memo. + """ + return len(self.uid_memo_dict) > 0 + + def _map_topics_to_memo(self, topics: List[str], memo_id: str, memo: Memo) -> None: + """ + Adds a mapping in the vec DB from each topic to the memo. + """ + self.logger.enter_function() + self.logger.info("\nINSIGHT\n{}".format(memo.insight)) + for topic in topics: + self.logger.info("\n TOPIC = {}".format(topic)) + self.string_map.add_input_output_pair(topic, memo_id) + self.uid_memo_dict[memo_id] = memo + self.save_memos() + self.logger.leave_function() + + def add_memo(self, insight_str: str, topics: List[str], task_str: Optional[str] = None) -> None: + """ + Adds an insight to the memory bank, given topics related to the insight, and optionally the task. + """ + self.logger.enter_function() + self.last_memo_id += 1 + id_str = str(self.last_memo_id) + insight = Memo(insight=insight_str, task=task_str) + self._map_topics_to_memo(topics, id_str, insight) + self.logger.leave_function() + + def add_task_with_solution(self, task: str, solution: str, topics: List[str]) -> None: + """ + Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight. + This is useful when the insight is a demonstration of how to solve a given type of task. + """ + self.logger.enter_function() + self.last_memo_id += 1 + id_str = str(self.last_memo_id) + # Prepend the insight to the task description for context. + insight_str = "Example task:\n\n{}\n\nExample solution:\n\n{}".format(task, solution) + memo = Memo(insight=insight_str, task=task) + self._map_topics_to_memo(topics, id_str, memo) + self.logger.leave_function() + + def get_relevant_memos(self, topics: List[str]) -> List[Memo]: + """ + Returns any memos from the memory bank that appear sufficiently relevant to the input topics. + """ + self.logger.enter_function() + + # Retrieve all topic matches, and gather them into a single list. + matches: List[Tuple[str, str, float]] = [] # Each match is a tuple: (topic, memo_id, distance) + for topic in topics: + matches.extend(self.string_map.get_related_string_pairs(topic, self.n_results, self.distance_threshold)) + + # Build a dict of memo-relevance pairs from the matches. + memo_relevance_dict: Dict[str, float] = {} + for match in matches: + relevance = self.relevance_conversion_threshold - match[2] + memo_id = match[1] + if memo_id in memo_relevance_dict: + memo_relevance_dict[memo_id] += relevance + else: + memo_relevance_dict[memo_id] = relevance + + # Log the details of all the retrieved memos. + self.logger.info("\n{} POTENTIALLY RELEVANT MEMOS".format(len(memo_relevance_dict))) + for memo_id, relevance in memo_relevance_dict.items(): + memo = self.uid_memo_dict[memo_id] + details = "" + if memo.task is not None: + details += "\n TASK: {}\n".format(memo.task) + details += "\n INSIGHT: {}\n\n RELEVANCE: {:.3f}\n".format(memo.insight, relevance) + self.logger.info(details) + + # Sort the memo-relevance pairs by relevance, in descending order. + memo_relevance_dict = dict(sorted(memo_relevance_dict.items(), key=lambda item: item[1], reverse=True)) + + # Compose the list of sufficiently relevant memos to return. + memo_list: List[Memo] = [] + for memo_id in memo_relevance_dict: + if memo_relevance_dict[memo_id] >= 0: + memo_list.append(self.uid_memo_dict[memo_id]) + + self.logger.leave_function() + return memo_list diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_prompter.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_prompter.py new file mode 100644 index 0000000..319f21b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_prompter.py @@ -0,0 +1,289 @@ +import time +from typing import List, Union + +from agentdhal_core import Image +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + SystemMessage, + UserMessage, +) + +from .utils._functions import UserContent +from .utils.page_logger import PageLogger + + +class Prompter: + """ + Centralizes most of the Apprentice prompts sent to the model client. + + Args: + client: The client to call the model. + logger: An optional logger. If None, no logging will be performed. + """ + + def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None: + if logger is None: + logger = PageLogger() # Nothing will be logged by this object. + self.logger = logger + + self.client = client + self.default_system_message_content = "You are a helpful assistant." + self.time_spent_in_model_calls = 0.0 + self.num_model_calls = 0 + self.start_time = time.time() + + # Create the chat history + self._chat_history: List[LLMMessage] = [] + + async def call_model( + self, + summary: str, + user_content: UserContent, + system_message_content: str | None = None, + keep_these_messages: bool = True, + ) -> str: + """ + Calls the model client with the given input and returns the response. + """ + # Prepare the input message list + if system_message_content is None: + system_message_content = self.default_system_message_content + system_message: LLMMessage + if self.client.model_info["family"] == "o1": + # No system message allowed, so pass it as the first user message. + system_message = UserMessage(content=system_message_content, source="User") + else: + # System message allowed. + system_message = SystemMessage(content=system_message_content) + + user_message = UserMessage(content=user_content, source="User") + input_messages = [system_message] + self._chat_history + [user_message] + + # Double check the types of the input messages. + for message in input_messages: + for part in message.content: + assert isinstance(part, str) or isinstance(part, Image), "Invalid message content type: {}".format( + type(part) + ) + + # Call the model + start_time = time.time() + response = await self.client.create(input_messages) + assert isinstance(response, CreateResult) + response_string = response.content + assert isinstance(response_string, str) + response_message = AssistantMessage(content=response_string, source="Assistant") + assert isinstance(response_message, AssistantMessage) + self.time_spent_in_model_calls += time.time() - start_time + self.num_model_calls += 1 + + # Log the model call + self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response) + + # Manage the chat history + if keep_these_messages: + self._chat_history.append(user_message) + self._chat_history.append(response_message) + + # Return the response as a string for now + return response_string + + def _clear_history(self) -> None: + """ + Empties the message list containing the chat history. + """ + self._chat_history = [] + + async def learn_from_failure( + self, task_description: str, memory_section: str, final_response: str, expected_answer: str, work_history: str + ) -> str: + """ + Tries to create an insight to help avoid the given failure in the future. + """ + sys_message = """- You are a patient and thorough teacher. +- Your job is to review work done by students and help them learn how to do better.""" + + user_message: List[Union[str, Image]] = [] + user_message.append("# A team of students made a mistake on the following task:\n") + user_message.extend([task_description]) + + if len(memory_section) > 0: + user_message.append(memory_section) + + user_message.append("# Here's the expected answer, which would have been correct:\n") + user_message.append(expected_answer) + + user_message.append("# Here is the students' answer, which was INCORRECT:\n") + user_message.append(final_response) + + user_message.append("# Please review the students' work which follows:\n") + user_message.append("**----- START OF STUDENTS' WORK -----**\n\n") + user_message.append(work_history) + user_message.append("\n**----- END OF STUDENTS' WORK -----**\n\n") + + user_message.append( + "# Now carefully review the students' work above, explaining in detail what the students did right and what they did wrong.\n" + ) + + self._clear_history() + await self.call_model( + summary="Ask the model to learn from this failure", + system_message_content=sys_message, + user_content=user_message, + ) + user_message = [ + "Now put yourself in the mind of the students. What misconception led them to their incorrect answer?" + ] + await self.call_model( + summary="Ask the model to state the misconception", + system_message_content=sys_message, + user_content=user_message, + ) + + user_message = [ + "Please express your key insights in the form of short, general advice that will be given to the students. Just one or two sentences, or they won't bother to read it." + ] + insight = await self.call_model( + summary="Ask the model to formulate a concise insight", + system_message_content=sys_message, + user_content=user_message, + ) + return insight + + async def find_index_topics(self, input_string: str) -> List[str]: + """ + Returns a list of topics related to the given string. + """ + sys_message = """You are an expert at semantic analysis.""" + + user_message: List[Union[str, Image]] = [] + user_message.append("""- My job is to create a thorough index for a book called Task Completion, and I need your help. +- Every paragraph in the book needs to be indexed by all the topics related to various kinds of tasks and strategies for completing them. +- Your job is to read the text below and extract the task-completion topics that are covered. +- The number of topics depends on the length and content of the text. But you should list at least one topic, and potentially many more. +- Each topic you list should be a meaningful phrase composed of a few words. Don't use whole sentences as topics. +- Don't include details that are unrelated to the general nature of the task, or a potential strategy for completing tasks. +- List each topic on a separate line, without any extra text like numbering, or bullets, or any other formatting, because we don't want those things in the index of the book.\n\n""") + + user_message.append("# Text to be indexed\n") + user_message.append(input_string) + + self._clear_history() + topics = await self.call_model( + summary="Ask the model to extract topics", system_message_content=sys_message, user_content=user_message + ) + + # Parse the topics into a list. + topic_list: List[str] = [] + for line in topics.split("\n"): + if len(line) > 0: + topic_list.append(line) + + return topic_list + + async def generalize_task(self, task_description: str, revise: bool | None = True) -> str: + """ + Attempts to rewrite a task description in a more general form. + """ + + sys_message = """You are a helpful and thoughtful assistant.""" + + user_message: List[Union[str, Image]] = [ + "We have been given a task description. Our job is not to complete the task, but merely rephrase the task in simpler, more general terms, if possible. Please reach through the following task description, then explain your understanding of the task in detail, as a single, flat list of all the important points." + ] + user_message.append("\n# Task description") + user_message.append(task_description) + + self._clear_history() + generalized_task = await self.call_model( + summary="Ask the model to rephrase the task in a list of important points", + system_message_content=sys_message, + user_content=user_message, + ) + + if revise: + user_message = [ + "Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant." + ] + await self.call_model( + summary="Ask the model to identify irrelevant points", + system_message_content=sys_message, + user_content=user_message, + ) + + user_message = [ + "Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list." + ] + generalized_task = await self.call_model( + summary="Ask the model to make a final list of general terms", + system_message_content=sys_message, + user_content=user_message, + ) + + return generalized_task + + async def validate_insight(self, insight: str, task_description: str) -> bool: + """ + Judges whether the insight could help solve the task. + """ + + sys_message = """You are a helpful and thoughtful assistant.""" + + user_message: List[Union[str, Image]] = [ + """We have been given a potential insight that may or may not be useful for solving a given task. +- First review the following task. +- Then review the insight that follows, and consider whether it might help solve the given task. +- Do not attempt to actually solve the task. +- Reply with a single character, '1' if the insight may be useful, or '0' if it is not.""" + ] + user_message.append("\n# Task description") + user_message.append(task_description) + user_message.append("\n# Possibly useful insight") + user_message.append(insight) + self._clear_history() + response = await self.call_model( + summary="Ask the model to validate the insight", + system_message_content=sys_message, + user_content=user_message, + ) + return response == "1" + + async def extract_task(self, text: str) -> str | None: + """ + Returns a task found in the given text, or None if not found. + """ + sys_message = """You are a helpful and thoughtful assistant.""" + user_message: List[Union[str, Image]] = [ + """Does the following text contain a question or a some task we are being asked to perform? +- If so, please reply with the full question or task description, along with any supporting information, but without adding extra commentary or formatting. +- If the task is just to remember something, that doesn't count as a task, so don't include it. +- If there is no question or task in the text, simply write "None" with no punctuation.""" + ] + user_message.append("\n# Text to analyze") + user_message.append(text) + self._clear_history() + response = await self.call_model( + summary="Ask the model to extract a task", system_message_content=sys_message, user_content=user_message + ) + return response if response != "None" else None + + async def extract_advice(self, text: str) -> str | None: + """ + Returns advice from the given text, or None if not found. + """ + sys_message = """You are a helpful and thoughtful assistant.""" + user_message: List[Union[str, Image]] = [ + """Does the following text contain any information or advice that might be useful later? +- If so, please copy the information or advice, adding no extra commentary or formatting. +- If there is no potentially useful information or advice at all, simply write "None" with no punctuation.""" + ] + user_message.append("\n# Text to analyze") + user_message.append(text) + self._clear_history() + response = await self.call_model( + summary="Ask the model to extract advice", system_message_content=sys_message, user_content=user_message + ) + return response if response != "None" else None diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_string_similarity_map.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_string_similarity_map.py new file mode 100644 index 0000000..1510c41 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/_string_similarity_map.py @@ -0,0 +1,124 @@ +import os +import pickle +from typing import Dict, List, Tuple, Union + +import chromadb +from chromadb.api.types import ( + QueryResult, +) +from chromadb.config import Settings + +from .utils.page_logger import PageLogger + + +class StringSimilarityMap: + """ + Provides storage and similarity-based retrieval of string pairs using a vector database. + Each DB entry is a pair of strings: an input string and an output string. + The input string is embedded and used as the retrieval key. + The output string can be anything, but it's typically used as a dict key. + Vector embeddings are currently supplied by Chroma's default Sentence Transformers. + + Args: + - reset: True to clear the DB immediately after creation. + - path_to_db_dir: Path to the directory where the DB is stored. + - logger: An optional logger. If None, no logging will be performed. + """ + + def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None = None) -> None: + if logger is None: + logger = PageLogger() # Nothing will be logged by this object. + self.logger = logger + self.path_to_db_dir = path_to_db_dir + + # Load or create the vector DB on disk. + chromadb_settings = Settings( + anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir + ) + self.db_client = chromadb.Client(chromadb_settings) + self.vec_db = self.db_client.create_collection("string-pairs", get_or_create=True) # The collection is the DB. + + # Load or create the associated string-pair dict on disk. + self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl") + self.uid_text_dict: Dict[str, Tuple[str, str]] = {} + self.last_string_pair_id = 0 + if (not reset) and os.path.exists(self.path_to_dict): + self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict)) + with open(self.path_to_dict, "rb") as f: + self.uid_text_dict = pickle.load(f) + self.last_string_pair_id = len(self.uid_text_dict) + if len(self.uid_text_dict) > 0: + self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict))) + self._log_string_pairs() + + # Clear the DB if requested. + if reset: + self.reset_db() + + def _log_string_pairs(self) -> None: + """ + Logs all string pairs currently in the map. + """ + self.logger.debug("LIST OF STRING PAIRS") + for uid, text in self.uid_text_dict.items(): + input_text, output_text = text + self.logger.debug(" ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text)) + + def save_string_pairs(self) -> None: + """ + Saves the string-pair dict (self.uid_text_dict) to disk. + """ + self.logger.debug("\nSAVING STRING SIMILARITY MAP TO DISK at {}".format(self.path_to_dict)) + with open(self.path_to_dict, "wb") as file: + pickle.dump(self.uid_text_dict, file) + + def reset_db(self) -> None: + """ + Forces immediate deletion of the DB's contents, in memory and on disk. + """ + self.logger.debug("\nCLEARING STRING-PAIR MAP") + self.db_client.delete_collection("string-pairs") + self.vec_db = self.db_client.create_collection("string-pairs") + self.uid_text_dict = {} + self.save_string_pairs() + + def add_input_output_pair(self, input_text: str, output_text: str) -> None: + """ + Adds one input-output string pair to the DB. + """ + self.last_string_pair_id += 1 + self.vec_db.add(documents=[input_text], ids=[str(self.last_string_pair_id)]) + self.uid_text_dict[str(self.last_string_pair_id)] = input_text, output_text + self.logger.debug( + "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format( + self.last_string_pair_id, input_text, output_text + ) + ) + # self._log_string_pairs() # For deeper debugging, uncomment to log all string pairs after each addition. + + def get_related_string_pairs( + self, query_text: str, n_results: int, threshold: Union[int, float] + ) -> List[Tuple[str, str, float]]: + """ + Retrieves up to n string pairs that are related to the given query text within the specified distance threshold. + """ + string_pairs_with_distances: List[Tuple[str, str, float]] = [] + if n_results > len(self.uid_text_dict): + n_results = len(self.uid_text_dict) + if n_results > 0: + results: QueryResult = self.vec_db.query(query_texts=[query_text], n_results=n_results) + num_results = len(results["ids"][0]) + for i in range(num_results): + uid = results["ids"][0][i] + input_text = results["documents"][0][i] if results["documents"] else "" + distance = results["distances"][0][i] if results["distances"] else 0.0 + if distance < threshold: + input_text_2, output_text = self.uid_text_dict[uid] + assert input_text == input_text_2 + self.logger.debug( + "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format( + input_text, output_text, distance + ) + ) + string_pairs_with_distances.append((input_text, output_text, distance)) + return string_pairs_with_distances diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/memory_controller.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/memory_controller.py new file mode 100644 index 0000000..9915598 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/memory_controller.py @@ -0,0 +1,478 @@ +from typing import TYPE_CHECKING, Awaitable, Callable, List, Tuple, TypedDict + +from agentdhal_core.models import ( + ChatCompletionClient, +) + +from ._memory_bank import Memo, MemoryBank +from ._prompter import Prompter + +if TYPE_CHECKING: + from ._memory_bank import MemoryBankConfig +from .utils.grader import Grader +from .utils.page_logger import PageLogger + + +# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating +# the settings that change frequently, as when loading many settings from a single YAML file. +class MemoryControllerConfig(TypedDict, total=False): + generalize_task: bool + revise_generalized_task: bool + generate_topics: bool + validate_memos: bool + max_memos_to_retrieve: int + max_train_trials: int + max_test_trials: int + MemoryBank: "MemoryBankConfig" + + +class MemoryController: + """ + (EXPERIMENTAL, RESEARCH IN PROGRESS) + + Implements fast, memory-based learning, and manages the flow of information to and from a memory bank. + + Args: + reset: True to empty the memory bank before starting. + client: The model client to use internally. + task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller. + config: An optional dict that can be used to override the following values: + + - generalize_task: Whether to rewrite tasks in more general terms. + - revise_generalized_task: Whether to critique then rewrite the generalized task. + - generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks. + - validate_memos: Whether to apply a final validation stage to retrieved memos. + - max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos(). + - max_train_trials: The maximum number of learning iterations to attempt when training on a task. + - max_test_trials: The total number of attempts made when testing for failure on a task. + - MemoryBank: A config dict passed to MemoryBank. + + logger: An optional logger. If None, a default logger will be created. + + Example: + + The `task-centric-memory` extra first needs to be installed: + + .. code-block:: bash + + pip install "agentdhal-ext[task-centric-memory]" + + The following code snippet shows how to use this class for the most basic storage and retrieval of memories.: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.experimental.task_centric_memory import MemoryController + from agentdhal_extensions.experimental.task_centric_memory.utils import PageLogger + + + async def main() -> None: + client = OpenAIChatCompletionClient(model="gpt-4o") + logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful. + memory_controller = MemoryController(reset=True, client=client, logger=logger) + + # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task. + await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color") + await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan") + await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite") + + # Retrieve memories for a new task that's related to only two of the stored memories. + memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?") + print("{} memories retrieved".format(len(memos))) + for memo in memos: + print("- " + memo.insight) + + + asyncio.run(main()) + """ + + def __init__( + self, + reset: bool, + client: ChatCompletionClient, + task_assignment_callback: Callable[[str], Awaitable[Tuple[str, str]]] | None = None, + config: MemoryControllerConfig | None = None, + logger: PageLogger | None = None, + ) -> None: + if logger is None: + logger = PageLogger({"level": "DEBUG"}) + self.logger = logger + self.logger.enter_function() + + # Apply default settings and any config overrides. + self.generalize_task = True + self.revise_generalized_task = True + self.generate_topics = True + self.validate_memos = True + self.max_memos_to_retrieve = 10 + self.max_train_trials = 10 + self.max_test_trials = 3 + memory_bank_config = None + if config is not None: + self.generalize_task = config.get("generalize_task", self.generalize_task) + self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task) + self.generate_topics = config.get("generate_topics", self.generate_topics) + self.validate_memos = config.get("validate_memos", self.validate_memos) + self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve) + self.max_train_trials = config.get("max_train_trials", self.max_train_trials) + self.max_test_trials = config.get("max_test_trials", self.max_test_trials) + memory_bank_config = config.get("MemoryBank", memory_bank_config) + + self.client = client + self.task_assignment_callback = task_assignment_callback + self.prompter = Prompter(client, logger) + self.memory_bank = MemoryBank(reset=reset, config=memory_bank_config, logger=logger) + self.grader = Grader(client, logger) + self.logger.leave_function() + + def reset_memory(self) -> None: + """ + Empties the memory bank in RAM and on disk. + """ + self.memory_bank.reset() + + async def train_on_task(self, task: str, expected_answer: str) -> None: + """ + Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories. + """ + self.logger.enter_function() + self.logger.info("Iterate on the task, possibly discovering a useful new insight.\n") + _, insight = await self._iterate_on_task(task, expected_answer) + if insight is None: + self.logger.info("No useful insight was discovered.\n") + else: + self.logger.info("A new insight was created:\n{}".format(insight)) + await self.add_memo(insight, task) + self.logger.leave_function() + + async def test_on_task(self, task: str, expected_answer: str, num_trials: int = 1) -> Tuple[str, int, int]: + """ + Assigns a task to the agent, along with any relevant memos retrieved from memory. + """ + self.logger.enter_function() + assert self.task_assignment_callback is not None + response = "" + num_successes = 0 + + for trial in range(num_trials): + self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1)) + task_plus_insights = task + + # Try to retrieve any relevant memories from the DB. + filtered_memos = await self.retrieve_relevant_memos(task) + filtered_insights = [memo.insight for memo in filtered_memos] + if len(filtered_insights) > 0: + self.logger.info("Relevant insights were retrieved from memory.\n") + memory_section = self._format_memory_section(filtered_insights) + if len(memory_section) > 0: + task_plus_insights = task + "\n\n" + memory_section + + # Attempt to solve the task. + self.logger.info("Try to solve the task.\n") + response, _ = await self.task_assignment_callback(task_plus_insights) + + # Check if the response is correct. + response_is_correct, extracted_answer = await self.grader.is_response_correct( + task, response, expected_answer + ) + self.logger.info("Extracted answer: {}".format(extracted_answer)) + if response_is_correct: + self.logger.info("Answer is CORRECT.\n") + num_successes += 1 + else: + self.logger.info("Answer is INCORRECT.\n") + + # Calculate the success rate as a percentage, rounded to the nearest whole number. + self.logger.info("\nSuccess rate: {}%\n".format(round((num_successes / num_trials) * 100))) + self.logger.leave_function() + return response, num_successes, num_trials + + async def add_memo(self, insight: str, task: None | str = None, index_on_both: bool = True) -> None: + """ + Adds one insight to the memory bank, using the task (if provided) as context. + """ + self.logger.enter_function() + + generalized_task = "" + if task is not None: + self.logger.info("\nGIVEN TASK:") + self.logger.info(task) + if self.generalize_task: + generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task) + else: + generalized_task = task + + self.logger.info("\nGIVEN INSIGHT:") + self.logger.info(insight) + + # Get a list of topics from the insight and the task (if provided). + if task is None: + text_to_index = insight + self.logger.info("\nTOPICS EXTRACTED FROM INSIGHT:") + else: + if index_on_both: + text_to_index = generalized_task.strip() + "\n(Hint: " + insight + ")" + self.logger.info("\nTOPICS EXTRACTED FROM TASK AND INSIGHT COMBINED:") + else: + text_to_index = task + self.logger.info("\nTOPICS EXTRACTED FROM TASK:") + + if self.generate_topics: + topics = await self.prompter.find_index_topics(text_to_index) + else: + topics = [text_to_index] + self.logger.info("\n".join(topics)) + self.logger.info("") + + # Add the insight to the memory bank. + self.memory_bank.add_memo(insight, topics, task) + self.logger.leave_function() + + async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None: + """ + Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight. + This is useful when the task-solution pair is an exemplar of solving a task related to some other task. + """ + self.logger.enter_function() + + self.logger.info("\nEXAMPLE TASK:") + self.logger.info(task) + + self.logger.info("\nEXAMPLE SOLUTION:") + self.logger.info(solution) + + # Get a list of topics from the task. + if self.generate_topics: + topics = await self.prompter.find_index_topics(task.strip()) + else: + topics = [task.strip()] + self.logger.info("\nTOPICS EXTRACTED FROM TASK:") + self.logger.info("\n".join(topics)) + self.logger.info("") + + # Add the task and solution (as a combined insight) to the memory bank. + self.memory_bank.add_task_with_solution(task=task, solution=solution, topics=topics) + self.logger.leave_function() + + async def retrieve_relevant_memos(self, task: str) -> List[Memo]: + """ + Retrieves any memos from memory that seem relevant to the task. + """ + self.logger.enter_function() + + if self.memory_bank.contains_memos(): + self.logger.info("\nCURRENT TASK:") + self.logger.info(task) + + # Get a list of topics from the generalized task. + if self.generalize_task: + generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task) + else: + generalized_task = task + if self.generate_topics: + task_topics = await self.prompter.find_index_topics(generalized_task) + else: + task_topics = [generalized_task] + self.logger.info("\nTOPICS EXTRACTED FROM TASK:") + self.logger.info("\n".join(task_topics)) + self.logger.info("") + + # Retrieve relevant memos from the memory bank. + memo_list = self.memory_bank.get_relevant_memos(topics=task_topics) + + # Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant. + validated_memos: List[Memo] = [] + for memo in memo_list: + if len(validated_memos) >= self.max_memos_to_retrieve: + break + if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task): + validated_memos.append(memo) + + self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos))) + for memo in validated_memos: + if memo.task is not None: + self.logger.info("\n TASK: {}".format(memo.task)) + self.logger.info("\n INSIGHT: {}".format(memo.insight)) + else: + self.logger.info("\nNO SUFFICIENTLY RELEVANT MEMOS WERE FOUND IN MEMORY") + validated_memos = [] + + self.logger.leave_function() + return validated_memos + + def _format_memory_section(self, memories: List[str]) -> str: + """ + Formats a list of memories as a section for appending to a task description. + """ + memory_section = "" + if len(memories) > 0: + memory_section = "## Important insights that may help solve tasks like this\n" + for mem in memories: + memory_section += "- " + mem + "\n" + return memory_section + + async def _test_for_failure( + self, task: str, task_plus_insights: str, expected_answer: str + ) -> Tuple[bool, str, str]: + """ + Attempts to solve the given task multiple times to find a failure case to learn from. + """ + self.logger.enter_function() + self.logger.info("\nTask description, including any insights: {}".format(task_plus_insights)) + self.logger.info("\nExpected answer: {}\n".format(expected_answer)) + + assert self.task_assignment_callback is not None + failure_found = False + response, work_history = "", "" + + for trial in range(self.max_test_trials): + self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1)) + + # Attempt to solve the task. + self.logger.info("Try to solve the task.") + response, work_history = await self.task_assignment_callback(task_plus_insights) + + response_is_correct, extracted_answer = await self.grader.is_response_correct( + task, response, expected_answer + ) + self.logger.info("Extracted answer: {}".format(extracted_answer)) + if response_is_correct: + self.logger.info("Answer is CORRECT.\n") + else: + self.logger.info("Answer is INCORRECT.\n Stop testing, and return the details of the failure.\n") + failure_found = True + break + + self.logger.leave_function() + return failure_found, response, work_history + + async def _iterate_on_task(self, task: str, expected_answer: str) -> Tuple[str, None | str]: + """ + Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories. + """ + self.logger.enter_function() + self.logger.info("\nTask description: {}".format(task)) + self.logger.info("\nExpected answer: {}\n".format(expected_answer)) + + final_response = "" + old_memos = await self.retrieve_relevant_memos(task) + old_insights = [memo.insight for memo in old_memos] + new_insights: List[str] = [] + last_insight = None + insight = None + successful_insight = None + + # Loop until success (or timeout) while learning from failures. + for trial in range(1, self.max_train_trials + 1): + self.logger.info("\n----- TRAIN TRIAL {} -----\n".format(trial)) + task_plus_insights = task + + # Add any new insights we've accumulated so far. + if last_insight is not None: + memory_section = self._format_memory_section(old_insights + [last_insight]) + else: + memory_section = self._format_memory_section(old_insights) + if len(memory_section) > 0: + task_plus_insights += "\n\n" + memory_section + + # Can we find a failure case to learn from? + failure_found, response, work_history = await self._test_for_failure( + task, task_plus_insights, expected_answer + ) + if not failure_found: + # No. Time to exit the loop. + self.logger.info("\nResponse is CORRECT.\n Stop looking for insights.\n") + # Was this the first trial? + if trial == 1: + # Yes. We should return the successful response, and no insight. + final_response = response + else: + # No. We learned a successful insight, which should be returned. + successful_insight = insight + break + + # Will we try again? + if trial == self.max_train_trials: + # No. We're out of training trials. + self.logger.info("\nNo more trials will be attempted.\n") + break + + # Try to learn from this failure. + self.logger.info("\nResponse is INCORRECT. Try to learn from this failure.\n") + insight = await self.prompter.learn_from_failure( + task, memory_section, response, expected_answer, work_history + ) + self.logger.info("\nInsight: {}\n".format(insight)) + new_insights.append(insight) + last_insight = insight + + # Return the answer from the last loop. + self.logger.info("\n{}\n".format(final_response)) + self.logger.leave_function() + return final_response, successful_insight + + async def _append_any_relevant_memories(self, task: str) -> str: + """ + Appends any relevant memories to the task description. + """ + self.logger.enter_function() + + filtered_memos = await self.retrieve_relevant_memos(task) + filtered_insights = [memo.insight for memo in filtered_memos] + if len(filtered_insights) > 0: + self.logger.info("Relevant insights were retrieved from memory.\n") + memory_section = self._format_memory_section(filtered_insights) + if len(memory_section) > 0: + task = task + "\n\n" + memory_section + + self.logger.leave_function() + return task + + async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str: + """ + Assigns a task to some agent through the task_assignment_callback, along with any relevant memories. + """ + self.logger.enter_function() + + assert self.task_assignment_callback is not None + + if use_memory: + task = await self._append_any_relevant_memories(task) + + # Attempt to solve the task. + self.logger.info("Try to solve the task.\n") + assert should_await + response, _ = await self.task_assignment_callback(task) + + self.logger.leave_function() + return response + + async def consider_memo_storage(self, text: str) -> str | None: + """ + Tries to extract any advice from the given text and add it to memory. + """ + self.logger.enter_function() + + advice = await self.prompter.extract_advice(text) + self.logger.info("Advice: {}".format(advice)) + if advice is not None: + await self.add_memo(insight=advice) + + self.logger.leave_function() + return advice + + async def handle_user_message(self, text: str, should_await: bool = True) -> str: + """ + Handles a user message by extracting any advice as an insight to be stored in memory, and then calling assign_task(). + """ + self.logger.enter_function() + + # Check for advice. + advice = await self.consider_memo_storage(text) + + # Assign the task through the task_assignment_callback, using memory only if no advice was just provided. + response = await self.assign_task(text, use_memory=(advice is None), should_await=should_await) + + self.logger.leave_function() + return response diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/__init__.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/__init__.py new file mode 100644 index 0000000..82bf516 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/__init__.py @@ -0,0 +1,15 @@ +from .apprentice import Apprentice, ApprenticeConfig +from .chat_completion_client_recorder import ChatCompletionClientRecorder +from .grader import Grader +from .page_logger import PageLogger, PageLoggerConfig +from .teachability import Teachability + +__all__ = [ + "Apprentice", + "ChatCompletionClientRecorder", + "Grader", + "PageLogger", + "Teachability", + "ApprenticeConfig", + "PageLoggerConfig", +] diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/_functions.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/_functions.py new file mode 100644 index 0000000..ce24636 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/_functions.py @@ -0,0 +1,96 @@ +import hashlib +import os +from typing import List, Tuple, Union + +from agentdhal_core import FunctionCall, Image +from agentdhal_core.models import FunctionExecutionResult + +# Convenience types +UserContent = Union[str, List[Union[str, Image]]] +AssistantContent = Union[str, List[FunctionCall]] +FunctionExecutionContent = List[FunctionExecutionResult] +SystemContent = str +MessageContent = UserContent | AssistantContent | SystemContent | FunctionExecutionContent + + +def message_content_to_str(message_content: MessageContent | None) -> str: + """ + Converts the message content to a string. + """ + if message_content is None: + return "" + elif isinstance(message_content, str): + return message_content + elif isinstance(message_content, List): + converted: List[str] = list() + for item in message_content: + if isinstance(item, str): + converted.append(item) + elif isinstance(item, Image): + converted.append("") + else: + converted.append(str(item).rstrip()) + return "\n".join(converted) + else: + raise AssertionError("Unexpected response type.") + + +def text_from_user_content(user_content: UserContent) -> str: + """ + Extracts just the text from the user content. + """ + if isinstance(user_content, str): + return user_content + elif isinstance(user_content, List): + text_list: List[str] = list() + for item in user_content: + if isinstance(item, str): + text_list.append(item.rstrip()) + return "\n\n".join(text_list) + else: + raise AssertionError("Unexpected response type.") + + +def single_image_from_user_content(user_content: UserContent) -> Union[Image, None]: + """ + Extracts a single image from the user content. + """ + image_to_return = None + if isinstance(user_content, str): + return None + elif isinstance(user_content, List): + for item in user_content: + if isinstance(item, Image): + assert image_to_return is None, "Only one image is currently allowed in the user content." + image_to_return = item + else: + raise AssertionError("Unexpected response type.") + return image_to_return + + +def hash_directory(directory: str, hash_algo: str = "sha256") -> Tuple[str, int, int]: + """Computes a hash representing the state of a directory, including its structure and file contents.""" + hash_func = hashlib.new(hash_algo) + + # Also count the number of files and sub-directories + num_files = 0 + num_subdirs = 0 + + for root, dirs, files in sorted(os.walk(directory)): # Ensure order for consistent hashing + num_files += len(files) + num_subdirs += len(dirs) + for dir_name in sorted(dirs): + hash_func.update(dir_name.encode()) # Hash directory names + + for file_name in sorted(files): + file_path = os.path.join(root, file_name) + hash_func.update(file_name.encode()) # Hash file names + + try: + with open(file_path, "rb") as f: + while chunk := f.read(4096): # Read in chunks + hash_func.update(chunk) + except Exception: + pass + + return hash_func.hexdigest(), num_files, num_subdirs diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/apprentice.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/apprentice.py new file mode 100644 index 0000000..a76ba39 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/apprentice.py @@ -0,0 +1,257 @@ +import random +import time +from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict + +from agentdhal_agentchat.agents import AssistantAgent +from agentdhal_agentchat.base import TaskResult +from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage +from agentdhal_core.models import ( + ChatCompletionClient, + LLMMessage, + SystemMessage, + UserMessage, +) + +from .page_logger import PageLogger + +if TYPE_CHECKING: + from ..memory_controller import MemoryControllerConfig + + +# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating +# the settings that change frequently, as when loading many settings from a single YAML file. +class ApprenticeConfig(TypedDict, total=False): + name_of_agent_or_team: str + disable_prefix_caching: bool + MemoryController: "MemoryControllerConfig" + + +class Apprentice: + """ + A minimal wrapper combining task-centric memory with an agent or team. + Applications may use the Apprentice class, or they may directly instantiate + and call the Memory Controller using this class as an example. + + Args: + client: The client to call the model. + config: An optional dict that can be used to override the following values: + + - name_of_agent_or_team: The name of the target agent or team for assigning tasks to. + - disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message. + - MemoryController: A config dict passed to MemoryController. + + logger: An optional logger. If None, a default logger will be created. + """ + + def __init__( + self, + client: ChatCompletionClient, + config: ApprenticeConfig | None = None, + logger: PageLogger | None = None, + ) -> None: + if logger is None: + logger = PageLogger({"level": "DEBUG"}) + self.logger = logger + + # Apply default settings and any config overrides. + self.name_of_agent_or_team = "AssistantAgent" + self.disable_prefix_caching = False + memory_controller_config = None + if config is not None: + self.name_of_agent_or_team = config.get("name_of_agent_or_team", self.name_of_agent_or_team) + self.disable_prefix_caching = config.get("disable_prefix_caching", self.disable_prefix_caching) + memory_controller_config = config.get("MemoryController", memory_controller_config) + + self.client = client + if self.disable_prefix_caching: + self.rand = random.Random() + self.rand.seed(int(time.time() * 1000)) + + # Create the MemoryController, which creates the MemoryBank. + from ..memory_controller import MemoryController + + self.memory_controller = MemoryController( + reset=True, + client=self.client, + task_assignment_callback=self.assign_task_to_agent_or_team, + config=memory_controller_config, + logger=self.logger, + ) + + def reset_memory(self) -> None: + """ + Resets the memory bank. + """ + self.memory_controller.reset_memory() + + async def handle_user_message(self, text: str, should_await: bool = True) -> str: + """ + Handles a user message, extracting any advice and assigning a task to the agent. + """ + self.logger.enter_function() + + # Pass the user message through to the memory controller. + response = await self.memory_controller.handle_user_message(text, should_await) + + self.logger.leave_function() + return response + + async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None: + """ + Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight. + This is useful when the insight is a demonstration of how to solve a given type of task. + """ + self.logger.enter_function() + + # Pass the task and solution through to the memory controller. + await self.memory_controller.add_task_solution_pair_to_memory(task, solution) + + self.logger.leave_function() + + async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str: + """ + Assigns a task to the agent, along with any relevant insights/memories. + """ + self.logger.enter_function() + + # Pass the task through to the memory controller. + response = await self.memory_controller.assign_task(task, use_memory, should_await) + + self.logger.leave_function() + return response + + async def train_on_task(self, task: str, expected_answer: str) -> None: + """ + Repeatedly assigns a task to the completion agent, and tries to learn from failures by creating useful insights as memories. + """ + self.logger.enter_function() + + # Pass the task through to the memory controller. + await self.memory_controller.train_on_task(task, expected_answer) + + self.logger.leave_function() + + async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]: + """ + Passes the given task to the target agent or team. + """ + self.logger.enter_function() + + # Pass the task through. + if self.name_of_agent_or_team == "MagenticOneGroupChat": + response, work_history = await self._assign_task_to_magentic_one(task) + elif self.name_of_agent_or_team == "AssistantAgent": + response, work_history = await self._assign_task_to_assistant_agent(task) + else: + raise AssertionError("Invalid base agent") + + self.logger.leave_function() + return response, work_history + + async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[Any, Any]: + """ + Passes the given task to a newly created AssistantAgent with a generic 6-step system prompt. + """ + self.logger.enter_function() + self.logger.info(task) + + system_message_content = """You are a helpful and thoughtful assistant. +In responding to every user message, you follow the same multi-step process given here: +1. Explain your understanding of the user message in detail, covering all the important points. +2. List as many possible responses as you can think of. +3. Carefully list and weigh the pros and cons (if any) of each possible response. +4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist. +5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory. +6. Finish by providing your final response in the particular format requested by the user.""" + + if self.disable_prefix_caching: + # Prepend a random int to disable prefix caching. + random_str = "({})\n\n".format(self.rand.randint(0, 1000000)) + system_message_content = random_str + system_message_content + + system_message: LLMMessage + if self.client.model_info["family"] == "o1": + # No system message allowed, so pass it as the first user message. + system_message = UserMessage(content=system_message_content, source="User") + else: + # System message allowed. + system_message = SystemMessage(content=system_message_content) + + user_message: LLMMessage = UserMessage(content=task, source="User") + system_message_list: List[LLMMessage] = [system_message] + user_message_list: List[LLMMessage] = [user_message] + input_messages: List[LLMMessage] = system_message_list + user_message_list + + assistant_agent = AssistantAgent( + "assistant_agent", + self.client, + system_message=system_message_content, + ) + + # Get the agent's response to the task. + task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User")) + messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages + message: BaseAgentEvent | BaseChatMessage = messages[-1] + response_str = message.to_text() + + # Log the model call + self.logger.log_model_task( + summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result + ) + self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str)) + + # Use the response as the work history as well. + work_history = response_str + + self.logger.leave_function() + return response_str, work_history + + async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]: + """ + Instantiates a MagenticOneGroupChat team, and passes the given task to it. + """ + self.logger.enter_function() + self.logger.info(task) + + general_agent = AssistantAgent( + "general_agent", + self.client, + description="A general GPT-4o AI assistant capable of performing a variety of tasks.", + ) + + from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer + + web_surfer = MultimodalWebSurfer( + name="web_surfer", + model_client=self.client, + downloads_folder="logs", + debug_dir="logs", + to_save_screenshots=True, + ) + + from agentdhal_agentchat.teams import MagenticOneGroupChat + + team = MagenticOneGroupChat( + [general_agent, web_surfer], + model_client=self.client, + max_turns=20, + ) + + # Get the team's response to the task. + task_result: TaskResult = await team.run(task=task) + + assert isinstance(task_result, TaskResult) + messages = task_result.messages + + response_str_list: List[str] = [] + for message in messages: + response_str_list.append(message.to_text()) + response_str = "\n".join(response_str_list) + + self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str)) + + # MagenticOne's response is the chat history, which we use here as the work history. + work_history = response_str + + self.logger.leave_function() + return response_str, work_history diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/chat_completion_client_recorder.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/chat_completion_client_recorder.py new file mode 100644 index 0000000..b75932b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/chat_completion_client_recorder.py @@ -0,0 +1,227 @@ +import json +import os +import warnings +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union + +from agentdhal_core import CancellationToken +from agentdhal_core.models import ( + ChatCompletionClient, + CreateResult, + LLMMessage, + ModelCapabilities, # type: ignore + ModelInfo, + RequestUsage, +) +from agentdhal_core.tools import Tool, ToolSchema +from pydantic import BaseModel + +from .page_logger import PageLogger + + +class RecordDict(TypedDict): + mode: Literal["create", "create_stream"] + messages: List[Mapping[str, Any]] + response: Dict[str, Any] + stream: List[Mapping[str, Any]] + + +class ChatCompletionClientRecorder(ChatCompletionClient): + """ + A chat completion client that supports fast, large-scale tests of code calling LLM clients. + + Two modes are supported: + + 1. "record": delegates to the underlying client while also recording the input messages and responses, + which are saved to disk when finalize() is called. + 2. "replay": loads previously recorded message and responses from disk, then on each call + checks that its message matches the recorded message, and returns the recorded response. + + The recorded data is stored as a JSON list of records. Each record is a dictionary with a "mode" + field (either "create" or "create_stream"), a serialized list of messages, and either a "response" (for + create calls) or a "stream" (a list of streamed outputs for create_stream calls). + + ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences: + + - ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client. + - ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages. + """ + + def __init__( + self, + client: ChatCompletionClient, + mode: Literal["record", "replay"], + session_file_path: str, + logger: PageLogger | None = None, + ) -> None: + if logger is None: + self.logger = PageLogger() # Disabled by default. + else: + self.logger = logger + self.logger.enter_function() + self.logger.info("Wrapping the base client in ChatCompletionClientRecorder.") + + self.base_client = client + self.mode = mode + self.session_file_path = os.path.expanduser(session_file_path) + self.records: List[RecordDict] = [] + self._record_index = 0 + self._num_checked_records = 0 + if self.mode == "record": + # Prepare to record the messages and responses. + self.logger.info("Recording mode enabled.\nRecording session to: " + self.session_file_path) + elif self.mode == "replay": + # Load the previously recorded messages and responses from disk. + self.logger.info("Replay mode enabled.\nRetrieving session from: " + self.session_file_path) + try: + with open(self.session_file_path, "r") as f: + self.records = json.load(f) + except Exception as e: + error_str = f"\nFailed to load recorded session: '{self.session_file_path}': {e}" + self.logger.error(error_str) + raise ValueError(error_str) from e + + self.logger.leave_function() + + async def create( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool | type[BaseModel]] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + tool_choice: Tool | Literal["auto", "required", "none"] = "auto", + ) -> CreateResult: + current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages] + if self.mode == "record": + response = await self.base_client.create( + messages, + tools=tools, + json_output=json_output, + tool_choice=tool_choice, + extra_create_args=extra_create_args, + cancellation_token=cancellation_token, + ) + + rec: RecordDict = { + "mode": "create", + "messages": current_messages, + "response": response.model_dump(), + "stream": [], + } + self.records.append(rec) + return response + elif self.mode == "replay": + if self._record_index >= len(self.records): + error_str = "\nNo more recorded turns to check." + self.logger.error(error_str) + raise ValueError(error_str) + rec = self.records[self._record_index] + if rec.get("mode") != "create": + error_str = f"\nRecorded call type mismatch at index {self._record_index}: expected 'create', got '{rec.get('mode')}'." + self.logger.error(error_str) + raise ValueError(error_str) + recorded_messages = rec.get("messages") + if recorded_messages != current_messages: + error_str = ( + "\nCurrent message list doesn't match the recorded message list. See the pagelogs for details." + ) + assert recorded_messages is not None + self.logger.log_dict_list(recorded_messages, "recorded message list") + assert current_messages is not None + self.logger.log_dict_list(current_messages, "current message list") + self.logger.error(error_str) + raise ValueError(error_str) + self._record_index += 1 + self._num_checked_records += 1 + + data = rec.get("response") + # Populate a CreateResult from the data. + assert data is not None + result = CreateResult( + content=data.get("content", ""), + finish_reason=data.get("finish_reason", "stop"), + usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)), + cached=True, + ) + return result + + else: + error_str = f"\nUnknown mode: {self.mode}" + self.logger.error(error_str) + raise ValueError(error_str) + + def create_stream( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool | type[BaseModel]] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + tool_choice: Tool | Literal["auto", "required", "none"] = "auto", + ) -> AsyncGenerator[Union[str, CreateResult], None]: + return self.base_client.create_stream( + messages, + tools=tools, + tool_choice=tool_choice, + json_output=json_output, + extra_create_args=extra_create_args, + cancellation_token=cancellation_token, + ) + + async def close(self) -> None: + await self.base_client.close() + + def actual_usage(self) -> RequestUsage: + # Calls base_client.actual_usage() and returns the result. + return self.base_client.actual_usage() + + def total_usage(self) -> RequestUsage: + # Calls base_client.total_usage() and returns the result. + return self.base_client.total_usage() + + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + # Calls base_client.count_tokens() and returns the result. + return self.base_client.count_tokens(messages, tools=tools) + + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + # Calls base_client.remaining_tokens() and returns the result. + return self.base_client.remaining_tokens(messages, tools=tools) + + @property + def capabilities(self) -> ModelCapabilities: # type: ignore + # Calls base_client.capabilities and returns the result. + warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2) + return self.base_client.capabilities + + @property + def model_info(self) -> ModelInfo: + # Calls base_client.model_info and returns the result. + return self.base_client.model_info + + def finalize(self) -> None: + """ + In record mode, saves the accumulated records to disk. + In replay mode, makes sure all the records were checked. + """ + self.logger.enter_function() + if self.mode == "record": + try: + # Create the directory if it doesn't exist. + os.makedirs(os.path.dirname(self.session_file_path), exist_ok=True) + # Write the records to disk. + with open(self.session_file_path, "w") as f: + json.dump(self.records, f, indent=2) + self.logger.info("\nRecorded session was saved to: " + self.session_file_path) + except Exception as e: + error_str = f"Failed to write records to '{self.session_file_path}': {e}" + self.logger.error(error_str) + raise ValueError(error_str) from e + elif self.mode == "replay": + if self._num_checked_records < len(self.records): + error_str = f"\nEarly termination. Only {self._num_checked_records} of the {len(self.records)} recorded turns were checked." + self.logger.error(error_str) + raise ValueError(error_str) + self.logger.info("\nRecorded session was fully replayed and checked.") + self.logger.leave_function() diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/grader.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/grader.py new file mode 100644 index 0000000..c26bc2d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/grader.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Tuple, Union + +from agentdhal_core import Image +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + LLMMessage, + SystemMessage, + UserMessage, +) + +from ._functions import UserContent +from .page_logger import PageLogger + +if TYPE_CHECKING: + from .apprentice import Apprentice + + +class Grader: + """ + Runs basic tests, and determines task success without limitation to string matches. + + Args: + client: The client to call the model. + logger: An optional logger. If None, no logging will be performed. + """ + + def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None: + if logger is None: + logger = PageLogger() # Nothing will be logged by this object. + self.logger = logger + self.client = client + + # Create the chat history + self._chat_history: List[LLMMessage] = [] + + async def test_apprentice( + self, + apprentice: Apprentice, + task_description: str, + expected_answer: str, + num_trials: int, + use_memory: bool, + client: ChatCompletionClient, + ) -> Tuple[int, int]: + self.logger.enter_function() + + self.logger.info("Testing the apprentice on the given task.\n") + + num_successes = 0 + + for trial in range(num_trials): + self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1)) + self.logger.info("Try to solve the task.\n") + response = await apprentice.assign_task(task_description, use_memory=use_memory) + response_is_correct, extracted_answer = await self.is_response_correct( + task_description, response, expected_answer + ) + self.logger.info("Extracted answer: {}".format(extracted_answer)) + if response_is_correct: + self.logger.info("Answer is CORRECT.\n") + num_successes += 1 + else: + self.logger.info("Answer is INCORRECT.\n") + + self.logger.info("\nSuccess rate: {}%\n".format(round((num_successes / num_trials) * 100))) + self.logger.leave_function() + return num_successes, num_trials + + async def call_model( + self, + summary: str, + user_content: UserContent, + system_message_content: str | None = None, + keep_these_messages: bool = True, + ) -> str: + """ + Calls the model client with the given input and returns the response. + """ + # Prepare the input message list + if system_message_content is None: + system_message_content = "You are a helpful assistant." + system_message: LLMMessage + if self.client.model_info["family"] == "o1": + # No system message allowed, so pass it as the first user message. + system_message = UserMessage(content=system_message_content, source="User") + else: + # System message allowed. + system_message = SystemMessage(content=system_message_content) + user_message = UserMessage(content=user_content, source="User") + input_messages = [system_message] + self._chat_history + [user_message] + + # Call the model. + response = await self.client.create(input_messages) + assert isinstance(response, CreateResult) + response_string = response.content + assert isinstance(response_string, str) + response_message = AssistantMessage(content=response_string, source="Assistant") + assert isinstance(response_message, AssistantMessage) + + # Log the model call + self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response) + + # Manage the chat history + if keep_these_messages: + self._chat_history.append(user_message) + self._chat_history.append(response_message) + + # Return the response as a string + return response_string + + def _clear_history(self) -> None: + """ + Empties the message list containing the chat history. + """ + self._chat_history = [] + + async def is_response_correct( + self, task_description: str, response_to_be_graded: str, correct_answer: str + ) -> Tuple[bool, str]: + """ + Determines whether the response is equivalent to the task's correct answer. + """ + self.logger.enter_function() + + sys_message = """You are a helpful and thoughtful assistant.""" + + # Ask the model to extract the answer from the response. + user_message: List[Union[str, Image]] = [] + user_message.append("""Your job is to extract a possible answer to the following question from the given text. +- First review the following task. +- Then review the text that follows, which may an answer, plus reasoning that led to the answer. +- Do not attempt to actually solve the task yourself. +- Don't try to judge whether the reasoning steps were correct. +- Simply respond by summarizing the answer described in the text, omitting any other parts of the text. +- If no answer is present can be extracted from the text, simply reply "None".""") + user_message.append("\n# Task description") + user_message.append(task_description) + user_message.append("\n# Text that may contain an answer") + user_message.append(response_to_be_graded) + user_message_arg: UserContent = user_message + self._clear_history() + extracted_answer = await self.call_model( + summary="Ask the model to extract the answer", + system_message_content=sys_message, + user_content=user_message_arg, + ) + self.logger.info("Extracted answer: " + extracted_answer) + + # Ask the model to check the answer for correctness. + user_message = [ + """Your job is to decide whether a given answer to a task is correct or not. +- You will be given the task description and the correct, gold-standard answer, along with the answer to be graded. +- In general, an answer is correct if it is equivalent to the correct answer. +- Specifically, the given answer must contain the important information from the correct answer, and must not in any way contradict the correct answer. +- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, or extra commentary. +- An answer should be considered correct if it omits information that is clearly inferred. + - For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct. +- Respond with a single character: '1' if the answer to be graded is correct", '0' if not.""" + ] + user_message.append("\n# Task description") + user_message.append(task_description) + user_message.append("\n# Correct answer") + user_message.append(correct_answer) + user_message.append("\n# Answer to be graded") + user_message.append(extracted_answer) + self._clear_history() + decision = await self.call_model( + summary="Ask the model to check the answer for correctness", + system_message_content=sys_message, + user_content=user_message, + ) + self.logger.info("Decision: " + decision) + + self.logger.leave_function() + return decision == "1", extracted_answer diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/page_logger.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/page_logger.py new file mode 100644 index 0000000..4cf020d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/page_logger.py @@ -0,0 +1,546 @@ +import inspect +import json +import os +import shutil +from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict + +from agentdhal_agentchat.base import TaskResult +from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage +from agentdhal_core import Image +from agentdhal_core.models import ( + AssistantMessage, + CreateResult, + FunctionExecutionResultMessage, + LLMMessage, + RequestUsage, + SystemMessage, + UserMessage, +) + +from ._functions import MessageContent, hash_directory + + +def _html_opening(file_title: str, finished: bool = False) -> str: + """ + Returns the opening text of a simple HTML file. + """ + refresh_tag = '' if not finished else "" + st = f""" + + + + {refresh_tag} + {file_title} + + + """ + return st + + +def _html_closing() -> str: + """ + Return the closing text of a simple HTML file. + """ + return """""" + + +# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating +# the settings that change frequently, as when loading many settings from a single YAML file. +class PageLoggerConfig(TypedDict, total=False): + level: str + path: str + + +class PageLogger: + """ + Logs text and images to a set of HTML pages, one per function/method, linked to each other in a call tree. + + Args: + config: An optional dict that can be used to override the following values: + + - level: The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE. + - path: The path to the directory where the log files will be written. + """ + + def __init__(self, config: PageLoggerConfig | None = None) -> None: + self.levels = { + "DEBUG": 10, + "INFO": 20, + "WARNING": 30, + "ERROR": 40, + "CRITICAL": 50, + "NONE": 100, + } + + # Apply default settings and any config overrides. + level_str = "NONE" # Default to no logging at all. + self.log_dir = "./pagelogs/default" + if config is not None: + level_str = config.get("level", level_str) + self.log_dir = config.get("path", self.log_dir) + self.level = self.levels[level_str] + self.log_dir = os.path.expanduser(self.log_dir) + + # If the logging level is set to NONE or higher, don't log anything. + if self.level >= self.levels["NONE"]: + return + + self.page_stack = PageStack() + self.pages: List[Page] = [] + self.last_page_id = 0 + self.name = "0 Call Tree" + self._create_run_dir() + self.flush() + self.finalized = False + + def __del__(self) -> None: + self.finalize() + + def finalize(self) -> None: + # Writes a hash of the log directory to a file for change detection. + if self.level >= self.levels["NONE"]: + return + + # Don't finalize the log if it has already been finalized. + if self.finalized: + return + + # Do nothing if the app is being forced to exit early. + if self.page_stack.size() > 0: + return + + self.flush(finished=True) + + # Write the hash and other details to a file. + hash_str, num_files, num_subdirs = hash_directory(self.log_dir) + hash_path = os.path.join(self.log_dir, "hash.txt") + with open(hash_path, "w") as f: + f.write(hash_str) + f.write("\n") + f.write("{} files\n".format(num_files)) + f.write("{} subdirectories\n".format(num_subdirs)) + + self.finalized = True + + @staticmethod + def _decorate_text(text: str, color: str, weight: str = "bold", demarcate: bool = False) -> str: + """ + Returns a string of text with HTML styling for weight and color. + """ + if demarcate: + text = f"<<<<< {text} >>>>>" + return f'{text}' + + @staticmethod + def _link_to_image(image_path: str, description: str) -> str: + """ + Returns an HTML string defining a thumbnail link to an image. + """ + # To avoid a bug in heml rendering aht displays underscores to the left of thumbnails, + # define the following string on a single line. + link = f"""{description}""" + return link + + def _get_next_page_id(self) -> int: + """Returns the next page id and increments the counter.""" + self.last_page_id += 1 + return self.last_page_id + + def _create_run_dir(self) -> None: + """Creates a fresh log directory.""" + if os.path.exists(self.log_dir): + shutil.rmtree(self.log_dir) + os.makedirs(self.log_dir) + + def _add_page(self, summary: str, show_in_call_tree: bool = True, finished: bool = True) -> "Page": + """ + Adds a new page to the log. + """ + page = Page( + page_logger=self, + index=self._get_next_page_id(), + summary=summary, + indent_level=len(self.page_stack.stack), + show_in_call_tree=show_in_call_tree, + finished=finished, + ) + self.pages.append(page) + self.flush() + if len(self.page_stack.stack) > 0: + # Insert a link to the new page into the calling page. + self.info("\n" + page.full_link) + return page + + def _log_text(self, text: str) -> None: + """ + Adds text to the current page. + """ + page = self.page_stack.top() + if page is not None: + page.add_lines(text, flush=True) + + def debug(self, line: str) -> None: + """ + Adds DEBUG text to the current page if debugging level <= DEBUG. + """ + if self.level <= self.levels["DEBUG"]: + self._log_text(line) + + def info(self, line: str) -> None: + """ + Adds INFO text to the current page if debugging level <= INFO. + """ + if self.level <= self.levels["INFO"]: + self._log_text(line) + + def warning(self, line: str) -> None: + """ + Adds WARNING text to the current page if debugging level <= WARNING. + """ + if self.level <= self.levels["WARNING"]: + self._log_text(line) + + def error(self, line: str) -> None: + """ + Adds ERROR text to the current page if debugging level <= ERROR. + """ + if self.level <= self.levels["ERROR"]: + self._log_text(line) + + def critical(self, line: str) -> None: + """ + Adds CRITICAL text to the current page if debugging level <= CRITICAL. + """ + if self.level <= self.levels["CRITICAL"]: + self._log_text(line) + + def _message_source(self, message: LLMMessage) -> str: + """ + Returns a decorated string indicating the source of a message. + """ + source = "UNKNOWN" + color = "black" + if isinstance(message, SystemMessage): + source = "SYSTEM" + color = "purple" + elif isinstance(message, UserMessage): + source = "USER" + color = "blue" + elif isinstance(message, AssistantMessage): + source = "ASSISTANT" + color = "green" + elif isinstance(message, FunctionExecutionResultMessage): + source = "FUNCTION" + color = "red" + return self._decorate_text(source, color, demarcate=True) + + def _format_message_content(self, message_content: MessageContent) -> str: + """ + Formats the message content for logging. + """ + # Start by converting the message content to a list of strings. + content_list: List[str] = [] + content = message_content + if isinstance(content, str): + content_list.append(content) + elif isinstance(content, list): + for item in content: + if isinstance(item, str): + content_list.append(item.rstrip()) + elif isinstance(item, Image): + # Save the image to disk. + image_filename = str(self._get_next_page_id()) + " image.jpg" + image_path = os.path.join(self.log_dir, image_filename) + item.image.save(image_path) + # Add a link to the image. + content_list.append(self._link_to_image(image_filename, "message_image")) + elif isinstance(item, Dict): + # Add a dictionary to the log. + json_str = json.dumps(item, indent=4) + content_list.append(json_str) + else: + content_list.append(str(item).rstrip()) + else: + content_list.append("") + + # Convert the list of strings to a single string containing newline separators. + output = "" + for item in content_list: + output += f"\n{item}\n" + return output + + def log_message_content(self, message_content: MessageContent, summary: str) -> None: + """ + Adds a page containing the message's content, including any images. + """ + if self.level > self.levels["INFO"]: + return None + page = self._add_page(summary=summary, show_in_call_tree=False) + self.page_stack.write_stack_to_page(page) + page.add_lines(self._format_message_content(message_content=message_content)) + page.flush() + + def log_dict_list(self, content: List[Mapping[str, Any]], summary: str) -> None: + """ + Adds a page containing a list of dicts. + """ + if self.level > self.levels["INFO"]: + return None + page = self._add_page(summary=summary, show_in_call_tree=False) + self.page_stack.write_stack_to_page(page) + + for item in content: + json_str = json.dumps(item, indent=4) + page.add_lines(json_str) + + page.flush() + + def _log_model_messages( + self, summary: str, input_messages: List[LLMMessage], response_str: str, usage: RequestUsage | None + ) -> Optional["Page"]: + """ + Adds a page containing the messages to a model (including any input images) and its response. + """ + page = self._add_page(summary=summary, show_in_call_tree=False) + self.page_stack.write_stack_to_page(page) + + if usage is not None: + page.add_lines("{} prompt tokens".format(usage.prompt_tokens)) + page.add_lines("{} completion tokens".format(usage.completion_tokens)) + for m in input_messages: + page.add_lines("\n" + self._message_source(m)) + page.add_lines(self._format_message_content(message_content=m.content)) + page.add_lines("\n" + self._decorate_text("ASSISTANT RESPONSE", "green", demarcate=True)) + page.add_lines("\n" + response_str + "\n") + page.flush() + return page + + def log_model_call( + self, summary: str, input_messages: List[LLMMessage], response: CreateResult + ) -> Optional["Page"]: + """ + Logs messages sent to a model and the TaskResult response to a new page. + """ + if self.level > self.levels["INFO"]: + return None + + response_str = response.content + if not isinstance(response_str, str): + response_str = "??" + + page = self._log_model_messages(summary, input_messages, response_str, response.usage) + return page + + def log_model_task( + self, summary: str, input_messages: List[LLMMessage], task_result: TaskResult + ) -> Optional["Page"]: + """ + Logs messages sent to a model and the TaskResult response to a new page. + """ + if self.level > self.levels["INFO"]: + return None + + messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages + message = messages[-1] + response_str = message.to_text() + if not isinstance(response_str, str): + response_str = "??" + + if hasattr(message, "models_usage"): + usage: RequestUsage | None = message.models_usage + else: + usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + + page = self._log_model_messages(summary, input_messages, response_str, usage) + return page + + def log_link_to_local_file(self, file_path: str) -> str: + """ + Returns a link to a local file in the log. + """ + file_name = os.path.basename(file_path) + link = f'{file_name}' + return link + + def add_link_to_image(self, description: str, source_image_path: str) -> None: + """ + Inserts a thumbnail link to an image to the page. + """ + # Remove every character from the string 'description' that is not alphanumeric or a space. + description = "".join(e for e in description if e.isalnum() or e.isspace()) + target_image_filename = str(self._get_next_page_id()) + " - " + description + # Copy the image to the log directory. + local_image_path = os.path.join(self.log_dir, target_image_filename) + shutil.copyfile(source_image_path, local_image_path) + self._log_text("\n" + description) + self._log_text(self._link_to_image(target_image_filename, description)) + + def flush(self, finished: bool = False) -> None: + """ + Writes the current state of the log to disk. + """ + if self.level > self.levels["INFO"]: + return + # Create a call tree of the log. + call_tree_path = os.path.join(self.log_dir, self.name + ".html") + with open(call_tree_path, "w") as f: + f.write(_html_opening("0 Call Tree", finished=finished)) + f.write(f"

{self.name}

") + f.write("\n") + for page in self.pages: + if page.show_in_call_tree: + f.write(page.line_text + "\n") + f.write("\n") + f.write(_html_closing()) + + def enter_function(self) -> Optional["Page"]: + """ + Adds a new page corresponding to the current function call. + """ + if self.level > self.levels["INFO"]: + return None + + page = None + frame_type = inspect.currentframe() + if frame_type is not None: + frame = frame_type.f_back # Get the calling frame + if frame is not None: + # Check if it's a method by looking for 'self' or 'cls' in f_locals + if "self" in frame.f_locals: + class_name = type(frame.f_locals["self"]).__name__ + elif "cls" in frame.f_locals: + class_name = frame.f_locals["cls"].__name__ + else: + class_name = None # Not part of a class + + if class_name is None: # Not part of a class + caller_name = frame.f_code.co_name + else: + caller_name = class_name + "." + frame.f_code.co_name + + # Create a new page for this function. + page = self._add_page(summary=caller_name, show_in_call_tree=True, finished=False) + self.page_stack.push(page) + self.page_stack.write_stack_to_page(page) + + page.add_lines("\nENTER {}".format(caller_name), flush=True) + return page + + def leave_function(self) -> None: + """ + Finishes the page corresponding to the current function call. + """ + if self.level > self.levels["INFO"]: + return None + page = self.page_stack.top() + if page is not None: + page.finished = True + page.add_lines("\nLEAVE {}".format(page.summary), flush=True) + self.page_stack.pop() + + +class Page: + """ + Represents a single HTML page in the logger output. + + Args: + page_logger: The PageLogger object that created this page. + index: The index of the page. + summary: A brief summary of the page's contents for display. + indent_level: The level of indentation in the call tree. + show_in_call_tree: Whether to display the page in the call tree. + finished: Whether the page is complete. + """ + + def __init__( + self, + page_logger: PageLogger, + index: int, + summary: str, + indent_level: int, + show_in_call_tree: bool = True, + finished: bool = True, + ): + """ + Initializes and writes to a new HTML page. + """ + self.page_logger = page_logger + self.index_str = str(index) + self.summary = summary + self.indent_level = indent_level + self.show_in_call_tree = show_in_call_tree + self.finished = finished + self.file_title = self.index_str + " " + self.summary + self.indentation_text = "| " * self.indent_level + self.full_link = f'{self.file_title}' + self.line_text = self.indentation_text + self.full_link + self.lines: List[str] = [] + self.flush() + + def add_lines(self, lines: str, flush: bool = False) -> None: + """ + Adds one or more lines to the page. + """ + lines_to_add: List[str] = [] + if "\n" in lines: + lines_to_add = lines.split("\n") + else: + lines_to_add.append(lines) + self.lines.extend(lines_to_add) + if flush: + self.flush() + + def flush(self) -> None: + """ + Writes the HTML page to disk. + """ + page_path = os.path.join(self.page_logger.log_dir, self.index_str + ".html") + with open(page_path, "w") as f: + f.write(_html_opening(self.file_title, finished=self.finished)) + f.write(f"

{self.file_title}

\n") + for line in self.lines: + try: + f.write(f"{line}\n") + except UnicodeEncodeError: + f.write("UnicodeEncodeError in this line.\n") + f.write(_html_closing()) + f.flush() + + +class PageStack: + """ + A call stack containing a list of currently active function pages in the order they called each other. + """ + + def __init__(self) -> None: + self.stack: List[Page] = [] + + def push(self, page: Page) -> None: + """Adds a page to the top of the stack.""" + self.stack.append(page) + + def pop(self) -> Page: + """Removes and returns the top page from the stack""" + return self.stack.pop() + + def size(self) -> int: + """Returns the number of pages in the stack.""" + return len(self.stack) + + def top(self) -> Page | None: + """Returns the top page from the stack without removing it""" + if self.size() == 0: + return None + return self.stack[-1] + + def write_stack_to_page(self, page: Page) -> None: + # Logs a properly indented string displaying the current call stack. + page.add_lines("\nCALL STACK") + for stack_page in self.stack: + page.add_lines(stack_page.line_text) + page.add_lines("") + page.add_lines("") + page.flush() diff --git a/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/teachability.py b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/teachability.py new file mode 100644 index 0000000..1de8747 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/experimental/task_centric_memory/utils/teachability.py @@ -0,0 +1,133 @@ +from typing import TYPE_CHECKING, Any + +from agentdhal_core import CancellationToken, Image +from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult +from agentdhal_core.model_context import ChatCompletionContext +from agentdhal_core.models import UserMessage + +if TYPE_CHECKING: + from agentdhal_extensions.experimental.task_centric_memory import MemoryController + + +class Teachability(Memory): + """ + Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice. + + Steps for usage: + + 1. Instantiate MemoryController. + 2. Instantiate Teachability, passing the memory controller as a parameter. + 3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter. + 4. Use the AssistantAgent as usual, such as for chatting with the user. + """ + + def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None: + """Initialize Teachability.""" + self._memory_controller = memory_controller + self._logger = memory_controller.logger + self._name = name or "teachability" + + @property + def name(self) -> str: + """Get the memory instance identifier.""" + return self._name + + def _extract_text(self, content_item: str | MemoryContent) -> str: + """Extract searchable text from content.""" + if isinstance(content_item, str): + return content_item + + content = content_item.content + mime_type = content_item.mime_type + + if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: + return str(content) + elif mime_type == MemoryMimeType.JSON: + if isinstance(content, dict): + # Store original JSON string representation + return str(content).lower() + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError(f"Unsupported content type: {mime_type}") + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + """ + Extracts any advice from the last user turn to be stored in memory, + and adds any relevant memories to the model context. + """ + self._logger.enter_function() + + # Extract text from the user's last message + messages = await model_context.get_messages() + if not messages: + self._logger.leave_function() + return UpdateContextResult(memories=MemoryQueryResult(results=[])) + last_message = messages[-1] + last_user_text = last_message.content if isinstance(last_message.content, str) else str(last_message) + + # Add any relevant memories to the chat history + query_results = await self.query(last_user_text) + if query_results.results: + memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)] + memory_context = "\nPotentially relevant memories:\n" + "\n".join(memory_strings) + await model_context.add_message(UserMessage(content=memory_context, source="user")) + + # Add any user advice to memory + await self._memory_controller.consider_memo_storage(last_user_text) + + self._logger.leave_function() + return UpdateContextResult(memories=query_results) + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + """ + Tries to extract any advice from the passed content and add it to memory. + """ + self._logger.enter_function() + + # Extract text from the incoming content + text = self._extract_text(content) + + # Check for advice to add to memory for later turns. + await self._memory_controller.consider_memo_storage(text) + + self._logger.leave_function() + + async def query( + self, + query: str | MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> MemoryQueryResult: + """ + Returns any memories that seem relevant to the query. + """ + self._logger.enter_function() + + task = self._extract_text(query) + memory_results: list[MemoryContent] = [] + filtered_memos = await self._memory_controller.retrieve_relevant_memos(task=task) + filtered_insights = [memo.insight for memo in filtered_memos] + for insight in filtered_insights: + self._logger.info(f"Insight: {insight}") + memory_content = MemoryContent( + content=insight, + mime_type="MemoryMimeType.TEXT", + metadata={}, + ) + memory_results.append(memory_content) + + self._logger.leave_function() + return MemoryQueryResult(results=memory_results) + + async def clear(self) -> None: + """Clear all entries from memory.""" + self._memory_controller.reset_memory() + + async def close(self) -> None: + """Clean up memory resources.""" + pass # No cleanup needed for this memory implementation diff --git a/agent_dhal/agentdhal_extensions/memory/__init__.py b/agent_dhal/agentdhal_extensions/memory/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/__init__.py @@ -0,0 +1 @@ + diff --git a/agent_dhal/agentdhal_extensions/memory/canvas/__init__.py b/agent_dhal/agentdhal_extensions/memory/canvas/__init__.py new file mode 100644 index 0000000..ad10924 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/canvas/__init__.py @@ -0,0 +1,4 @@ +from ._text_canvas import TextCanvas +from ._text_canvas_memory import TextCanvasMemory + +__all__ = ["TextCanvas", "TextCanvasMemory"] diff --git a/agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py b/agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py new file mode 100644 index 0000000..de2eca2 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Union + + +class BaseCanvas(ABC): + """ + An abstract protocol for "canvas" objects that maintain + revision history for file-like data. Concrete subclasses + can handle text, images, structured data, etc. + + .. warning:: + + This is an experimental API and may change in the future. + + """ + + @abstractmethod + def list_files(self) -> Dict[str, int]: + """ + Returns a dict of filename -> latest revision number. + """ + raise NotImplementedError + + @abstractmethod + def get_latest_content(self, filename: str) -> Union[str, bytes, Any]: + """ + Returns the latest version of a file's content. + """ + raise NotImplementedError + + @abstractmethod + def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None: + """ + Creates or updates the file content with a new revision. + """ + raise NotImplementedError + + @abstractmethod + def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str: + """ + Returns a diff (in some format) between two revisions. + """ + raise NotImplementedError + + @abstractmethod + def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None: + """ + Applies a patch/diff to the latest revision and increments the revision. + """ + raise NotImplementedError diff --git a/agent_dhal/agentdhal_extensions/memory/canvas/_canvas_writer.py b/agent_dhal/agentdhal_extensions/memory/canvas/_canvas_writer.py new file mode 100644 index 0000000..e751b05 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/canvas/_canvas_writer.py @@ -0,0 +1,64 @@ +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel + +from ._text_canvas import TextCanvas + + +class UpdateFileArgs(BaseModel): + filename: str + new_content: str + + +class UpdateFileResult(BaseModel): + status: str + + +class UpdateFileTool(BaseTool[UpdateFileArgs, UpdateFileResult]): + """ + Overwrites or creates a file in the canvas. + """ + + def __init__(self, canvas: TextCanvas): + super().__init__( + args_type=UpdateFileArgs, + return_type=UpdateFileResult, + name="update_file", + description="Create/update a file on the canvas with the provided content.", + ) + self._canvas = canvas + + async def run(self, args: UpdateFileArgs, cancellation_token: CancellationToken) -> UpdateFileResult: + self._canvas.add_or_update_file(args.filename, args.new_content) + return UpdateFileResult(status="OK") + + +class ApplyPatchArgs(BaseModel): + filename: str + patch_text: str + + +class ApplyPatchResult(BaseModel): + status: str + + +class ApplyPatchTool(BaseTool[ApplyPatchArgs, ApplyPatchResult]): + """ + Applies a unified diff patch to the given file on the canvas. + """ + + def __init__(self, canvas: TextCanvas): + super().__init__( + args_type=ApplyPatchArgs, + return_type=ApplyPatchResult, + name="apply_patch", + description=( + "Apply a unified diff patch to an existing file on the canvas. " + "The patch must be in diff/patch format. The file must exist or be created first." + ), + ) + self._canvas = canvas + + async def run(self, args: ApplyPatchArgs, cancellation_token: CancellationToken) -> ApplyPatchResult: + self._canvas.apply_patch(args.filename, args.patch_text) + return ApplyPatchResult(status="PATCH APPLIED") diff --git a/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py b/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py new file mode 100644 index 0000000..306a070 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py @@ -0,0 +1,192 @@ +import difflib +from typing import Any, Dict, List, Union + +try: # pragma: no cover + from unidiff import PatchSet +except ModuleNotFoundError: # pragma: no cover + PatchSet = None # type: ignore + +from ._canvas import BaseCanvas + + +class FileRevision: + """Tracks the history of one file's content.""" + + __slots__ = ("content", "revision") + + def __init__(self, content: str, revision: int) -> None: + self.content: str = content + self.revision: int = revision # e.g. an integer, a timestamp, or git hash + + +class TextCanvas(BaseCanvas): + """An in‑memory canvas that stores *text* files with full revision history. + + .. warning:: + + This is an experimental API and may change in the future. + + Besides the original CRUD‑like operations, this enhanced implementation adds: + + * **apply_patch** – applies patches using the ``unidiff`` library for accurate + hunk application and context line validation. + * **get_revision_content** – random access to any historical revision. + * **get_revision_diffs** – obtain the list of diffs applied between every + consecutive pair of revisions so that a caller can replay or audit the + full change history. + """ + + # ---------------------------------------------------------------------------------- + # Construction helpers + # ---------------------------------------------------------------------------------- + + def __init__(self) -> None: + # For each file we keep an *ordered* list of FileRevision where the last + # element is the most recent. Using a list keeps the memory footprint + # small and preserves order without any extra bookkeeping. + self._files: Dict[str, List[FileRevision]] = {} + + # ---------------------------------------------------------------------------------- + # Internal utilities + # ---------------------------------------------------------------------------------- + + def _latest_idx(self, filename: str) -> int: + """Return the index (not revision number) of the newest revision.""" + return len(self._files.get(filename, [])) - 1 + + def _ensure_file(self, filename: str) -> None: + if filename not in self._files: + raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.") + + # ---------------------------------------------------------------------------------- + # Revision inspection helpers + # ---------------------------------------------------------------------------------- + + def get_revision_content(self, filename: str, revision: int) -> str: # NEW 🚀 + """Return the exact content stored in *revision*. + + If the revision does not exist an empty string is returned so that + downstream code can handle the "not found" case without exceptions. + """ + for rev in self._files.get(filename, []): + if rev.revision == revision: + return rev.content + return "" + + def get_revision_diffs(self, filename: str) -> List[str]: # NEW 🚀 + """Return a *chronological* list of unified‑diffs for *filename*. + + Each element in the returned list represents the diff that transformed + revision *n* into revision *n+1* (starting at revision 1 → 2). + """ + revisions = self._files.get(filename, []) + diffs: List[str] = [] + for i in range(1, len(revisions)): + older, newer = revisions[i - 1], revisions[i] + diff = difflib.unified_diff( + older.content.splitlines(keepends=True), + newer.content.splitlines(keepends=True), + fromfile=f"{filename}@r{older.revision}", + tofile=f"{filename}@r{newer.revision}", + ) + diffs.append("".join(diff)) + return diffs + + # ---------------------------------------------------------------------------------- + # BaseCanvas interface implementation + # ---------------------------------------------------------------------------------- + + def list_files(self) -> Dict[str, int]: + """Return a mapping of *filename → latest revision number*.""" + return {fname: revs[-1].revision for fname, revs in self._files.items() if revs} + + def get_latest_content(self, filename: str) -> str: # noqa: D401 – keep API identical + """Return the most recent content or an empty string if the file is new.""" + revs = self._files.get(filename, []) + return revs[-1].content if revs else "" + + def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None: + """Create *filename* or append a new revision containing *new_content*.""" + if isinstance(new_content, bytes): + new_content = new_content.decode("utf-8") + if not isinstance(new_content, str): + raise ValueError(f"Expected str or bytes, got {type(new_content)}") + if filename not in self._files: + self._files[filename] = [FileRevision(new_content, 1)] + else: + last_rev_num = self._files[filename][-1].revision + self._files[filename].append(FileRevision(new_content, last_rev_num + 1)) + + def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str: + """Return a unified diff between *from_revision* and *to_revision*.""" + revisions = self._files.get(filename, []) + if not revisions: + return "" + # Fetch the contents for the requested revisions. + from_content = self.get_revision_content(filename, from_revision) + to_content = self.get_revision_content(filename, to_revision) + if from_content == "" and to_content == "": # one (or both) revision ids not found + return "" + diff = difflib.unified_diff( + from_content.splitlines(keepends=True), + to_content.splitlines(keepends=True), + fromfile=f"{filename}@r{from_revision}", + tofile=f"{filename}@r{to_revision}", + ) + return "".join(diff) + + def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None: + """Apply *patch_text* (unified diff) to the latest revision and save a new revision. + + Uses the *unidiff* library to accurately apply hunks and validate context lines. + """ + if isinstance(patch_data, bytes): + patch_data = patch_data.decode("utf-8") + if not isinstance(patch_data, str): + raise ValueError(f"Expected str or bytes, got {type(patch_data)}") + self._ensure_file(filename) + original_content = self.get_latest_content(filename) + + if PatchSet is None: + raise ImportError( + "The 'unidiff' package is required for patch application. Install with 'pip install unidiff'." + ) + + patch = PatchSet(patch_data) + # Our canvas stores exactly one file per patch operation so we + # use the first (and only) patched_file object. + if not patch: + raise ValueError("Empty patch text provided.") + patched_file = patch[0] + working_lines = original_content.splitlines(keepends=True) + line_offset = 0 + for hunk in patched_file: + # Calculate the slice boundaries in the *current* working copy. + start = hunk.source_start - 1 + line_offset + end = start + hunk.source_length + # Build the replacement block for this hunk. + replacement: List[str] = [] + for line in hunk: + if line.is_added or line.is_context: + replacement.append(line.value) + # removed lines (line.is_removed) are *not* added. + # Replace the slice with the hunk‑result. + working_lines[start:end] = replacement + line_offset += len(replacement) - (end - start) + new_content = "".join(working_lines) + + # Finally commit the new revision. + self.add_or_update_file(filename, new_content) + + # ---------------------------------------------------------------------------------- + # Convenience helpers + # ---------------------------------------------------------------------------------- + + def get_all_contents_for_context(self) -> str: # noqa: D401 – keep public API stable + """Return a summarised view of every file and its *latest* revision.""" + out: List[str] = ["=== CANVAS FILES ==="] + for fname, revs in self._files.items(): + latest = revs[-1] + out.append(f"File: {fname} (rev {latest.revision}):\n{latest.content}\n") + out.append("=== END OF CANVAS ===") + return "\n".join(out) diff --git a/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas_memory.py b/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas_memory.py new file mode 100644 index 0000000..30eca47 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas_memory.py @@ -0,0 +1,229 @@ +from typing import Any, Optional + +from agentdhal_core import CancellationToken +from agentdhal_core.memory import ( + Memory, + MemoryContent, + MemoryMimeType, + MemoryQueryResult, + UpdateContextResult, +) +from agentdhal_core.model_context import ChatCompletionContext +from agentdhal_core.models import SystemMessage + +from ._canvas_writer import ApplyPatchTool, UpdateFileTool +from ._text_canvas import TextCanvas + + +class TextCanvasMemory(Memory): + """ + A memory implementation that uses a Canvas for storing file-like content. + Inserts the current state of the canvas into the ChatCompletionContext on each turn. + + .. warning:: + + This is an experimental API and may change in the future. + + The TextCanvasMemory provides a persistent, file-like storage mechanism that can be used + by agents to read and write content. It automatically injects the current state of all files + in the canvas into the model context before each inference. + + This is particularly useful for: + - Allowing agents to create and modify documents over multiple turns + - Enabling collaborative document editing between multiple agents + - Maintaining persistent state across conversation turns + - Working with content too large to fit in a single message + + The canvas provides tools for: + - Creating or updating files with new content + - Applying patches (unified diff format) to existing files + + Examples: + + **Example: Using TextCanvasMemory with an AssistantAgent** + + The following example demonstrates how to create a TextCanvasMemory and use it with + an AssistantAgent to write and update a story file. + + .. code-block:: python + + import asyncio + from agentdhal_core import CancellationToken + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.messages import TextMessage + from agentdhal_extensions.memory.canvas import TextCanvasMemory + + + async def main(): + # Create a model client + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + + # Create the canvas memory + text_canvas_memory = TextCanvasMemory() + + # Get tools for working with the canvas + update_file_tool = text_canvas_memory.get_update_file_tool() + apply_patch_tool = text_canvas_memory.get_apply_patch_tool() + + # Create an agent with the canvas memory and tools + writer_agent = AssistantAgent( + name="Writer", + model_client=model_client, + description="A writer agent that creates and updates stories.", + system_message=''' + You are a Writer Agent. Your focus is to generate a story based on the user's request. + + Instructions for using the canvas: + + - The story should be stored on the canvas in a file named "story.md". + - If "story.md" does not exist, create it by calling the 'update_file' tool. + - If "story.md" already exists, generate a unified diff (patch) from the current + content to the new version, and call the 'apply_patch' tool to apply the changes. + + IMPORTANT: Do not include the full story text in your chat messages. + Only write the story content to the canvas using the tools. + ''', + tools=[update_file_tool, apply_patch_tool], + memory=[text_canvas_memory], + ) + + # Send a message to the agent + await writer_agent.on_messages( + [TextMessage(content="Write a short story about a bunny and a sunflower.", source="user")], + CancellationToken(), + ) + + # Retrieve the content from the canvas + story_content = text_canvas_memory.canvas.get_latest_content("story.md") + print("Story content from canvas:") + print(story_content) + + + if __name__ == "__main__": + asyncio.run(main()) + + **Example: Using TextCanvasMemory with multiple agents** + + The following example shows how to use TextCanvasMemory with multiple agents + collaborating on the same document. + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_extensions.memory.canvas import TextCanvasMemory + + + async def main(): + # Create a model client + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + + # Create the shared canvas memory + text_canvas_memory = TextCanvasMemory() + update_file_tool = text_canvas_memory.get_update_file_tool() + apply_patch_tool = text_canvas_memory.get_apply_patch_tool() + + # Create a writer agent + writer_agent = AssistantAgent( + name="Writer", + model_client=model_client, + description="A writer agent that creates stories.", + system_message="You write children's stories on the canvas in story.md.", + tools=[update_file_tool, apply_patch_tool], + memory=[text_canvas_memory], + ) + + # Create a critique agent + critique_agent = AssistantAgent( + name="Critique", + model_client=model_client, + description="A critique agent that provides feedback on stories.", + system_message="You review the story.md file and provide constructive feedback.", + memory=[text_canvas_memory], + ) + + # Create a team with both agents + team = RoundRobinGroupChat( + participants=[writer_agent, critique_agent], + termination_condition=TextMentionTermination("TERMINATE"), + max_turns=10, + ) + + # Run the team on a task + await team.run(task="Create a children's book about a bunny and a sunflower") + + # Get the final story + story = text_canvas_memory.canvas.get_latest_content("story.md") + print(story) + + + if __name__ == "__main__": + asyncio.run(main()) + """ + + def __init__(self, canvas: Optional[TextCanvas] = None): + super().__init__() + self.canvas = canvas if canvas is not None else TextCanvas() + + async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult: + """ + Inject the entire canvas summary (or a selected subset) as reference data. + Here, we just put it into a system message, but you could customize. + """ + snapshot = self.canvas.get_all_contents_for_context() + if snapshot.strip(): + msg = SystemMessage(content=snapshot) + await model_context.add_message(msg) + + # Return it for debugging/logging + memory_content = MemoryContent(content=snapshot, mime_type=MemoryMimeType.TEXT) + return UpdateContextResult(memories=MemoryQueryResult(results=[memory_content])) + + return UpdateContextResult(memories=MemoryQueryResult(results=[])) + + async def query( + self, query: str | MemoryContent, cancellation_token: Optional[CancellationToken] = None, **kwargs: Any + ) -> MemoryQueryResult: + """ + Potentially search for matching filenames or file content. + This example returns empty. + """ + return MemoryQueryResult(results=[]) + + async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None: + """ + Example usage: Possibly interpret content as a patch or direct file update. + Could also be done by a specialized "CanvasTool" instead. + """ + # NO-OP here, leaving actual changes to the CanvasTool + pass + + async def clear(self) -> None: + """Clear the entire canvas by replacing it with a new empty instance.""" + # Create a new TextCanvas instance instead of calling __init__ directly + self.canvas = TextCanvas() + + async def close(self) -> None: + pass + + def get_update_file_tool(self) -> UpdateFileTool: + """ + Returns an UpdateFileTool instance that works with this memory's canvas. + """ + return UpdateFileTool(self.canvas) + + def get_apply_patch_tool(self) -> ApplyPatchTool: + """ + Returns an ApplyPatchTool instance that works with this memory's canvas. + """ + return ApplyPatchTool(self.canvas) diff --git a/agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py b/agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py new file mode 100644 index 0000000..1d6ad04 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py @@ -0,0 +1,21 @@ +from ._chroma_configs import ( + ChromaDBVectorMemoryConfig, + CustomEmbeddingFunctionConfig, + DefaultEmbeddingFunctionConfig, + HttpChromaDBVectorMemoryConfig, + OpenAIEmbeddingFunctionConfig, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, +) +from ._chromadb import ChromaDBVectorMemory + +__all__ = [ + "ChromaDBVectorMemory", + "ChromaDBVectorMemoryConfig", + "PersistentChromaDBVectorMemoryConfig", + "HttpChromaDBVectorMemoryConfig", + "DefaultEmbeddingFunctionConfig", + "SentenceTransformerEmbeddingFunctionConfig", + "OpenAIEmbeddingFunctionConfig", + "CustomEmbeddingFunctionConfig", +] diff --git a/agent_dhal/agentdhal_extensions/memory/chromadb/_chroma_configs.py b/agent_dhal/agentdhal_extensions/memory/chromadb/_chroma_configs.py new file mode 100644 index 0000000..47c16fc --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/chromadb/_chroma_configs.py @@ -0,0 +1,137 @@ +"""Configuration classes for ChromaDB vector memory.""" + +from typing import Any, Callable, Dict, Literal, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + + +class DefaultEmbeddingFunctionConfig(BaseModel): + """Configuration for the default ChromaDB embedding function. + + Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2). + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + """ + + function_type: Literal["default"] = "default" + + +class SentenceTransformerEmbeddingFunctionConfig(BaseModel): + """Configuration for SentenceTransformer embedding functions. + + Allows specifying a custom SentenceTransformer model for embeddings. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + Args: + model_name (str): Name of the SentenceTransformer model to use. + Defaults to "all-MiniLM-L6-v2". + + Example: + .. code-block:: python + + from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig + + _ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2") + """ + + function_type: Literal["sentence_transformer"] = "sentence_transformer" + model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use") + + +class OpenAIEmbeddingFunctionConfig(BaseModel): + """Configuration for OpenAI embedding functions. + + Uses OpenAI's embedding API for generating embeddings. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + Args: + api_key (str): OpenAI API key. If empty, will attempt to use environment variable. + model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002". + + Example: + .. code-block:: python + + from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig + + _ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small") + """ + + function_type: Literal["openai"] = "openai" + api_key: str = Field(default="", description="OpenAI API key") + model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name") + + +class CustomEmbeddingFunctionConfig(BaseModel): + """Configuration for custom embedding functions. + + Allows using a custom function that returns a ChromaDB-compatible embedding function. + + .. versionadded:: v0.4.1 + Support for custom embedding functions in ChromaDB memory. + + .. warning:: + Configurations containing custom functions are not serializable. + + Args: + function (Callable): Function that returns a ChromaDB-compatible embedding function. + params (Dict[str, Any]): Parameters to pass to the function. + """ + + function_type: Literal["custom"] = "custom" + function: Callable[..., Any] = Field(description="Function that returns an embedding function") + params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function") + + +# Tagged union type for embedding function configurations +EmbeddingFunctionConfig = Annotated[ + Union[ + DefaultEmbeddingFunctionConfig, + SentenceTransformerEmbeddingFunctionConfig, + OpenAIEmbeddingFunctionConfig, + CustomEmbeddingFunctionConfig, + ], + Field(discriminator="function_type"), +] + + +class ChromaDBVectorMemoryConfig(BaseModel): + """Base configuration for ChromaDB-based memory implementation. + + .. versionchanged:: v0.4.1 + Added support for custom embedding functions via embedding_function_config. + """ + + client_type: Literal["persistent", "http"] + collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection") + distance_metric: str = Field(default="cosine", description="Distance metric for similarity search") + k: int = Field(default=3, description="Number of results to return in queries") + score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold") + allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client") + tenant: str = Field(default="default_tenant", description="Tenant to use") + database: str = Field(default="default_database", description="Database to use") + embedding_function_config: EmbeddingFunctionConfig = Field( + default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function" + ) + + +class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): + """Configuration for persistent ChromaDB memory.""" + + client_type: Literal["persistent", "http"] = "persistent" + persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage") + + +class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig): + """Configuration for HTTP ChromaDB memory.""" + + client_type: Literal["persistent", "http"] = "http" + host: str = Field(default="localhost", description="Host of the remote server") + port: int = Field(default=8000, description="Port of the remote server") + ssl: bool = Field(default=False, description="Whether to use HTTPS") + headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server") diff --git a/agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py b/agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py new file mode 100644 index 0000000..9c33534 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py @@ -0,0 +1,459 @@ +import logging +import uuid +from typing import Any, List + +from agentdhal_core import CancellationToken, Component, Image +from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult +from agentdhal_core.model_context import ChatCompletionContext +from agentdhal_core.models import SystemMessage +from chromadb import HttpClient, PersistentClient +from chromadb.api.models.Collection import Collection +from chromadb.api.types import Document, Metadata +from typing_extensions import Self + +from ._chroma_configs import ( + ChromaDBVectorMemoryConfig, + CustomEmbeddingFunctionConfig, + DefaultEmbeddingFunctionConfig, + HttpChromaDBVectorMemoryConfig, + OpenAIEmbeddingFunctionConfig, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, +) + +logger = logging.getLogger(__name__) + + +try: + from chromadb.api import ClientAPI +except ImportError as e: + raise ImportError( + "To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`" + ) from e + + +class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]): + """ + Store and retrieve memory using vector similarity search powered by ChromaDB. + + `ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for + storing and retrieving content based on semantic similarity. It enhances agents with the ability + to recall contextually relevant information during conversations by leveraging vector embeddings + to find similar content. + + This implementation serves as a reference for more complex memory systems using vector embeddings. + For advanced use cases requiring specialized formatting of retrieved content, users should extend + this class and override the `update_context()` method. + + This implementation requires the ChromaDB extra to be installed. Install with: + + .. code-block:: bash + + pip install "agentdhal-ext[chromadb]" + + Args: + config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory. + If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values. + Two config types are supported: + * PersistentChromaDBVectorMemoryConfig: For local storage + * HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server + + Example: + + .. code-block:: python + + import os + import asyncio + from pathlib import Path + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core.memory import MemoryContent, MemoryMimeType + from agentdhal_extensions.memory.chromadb import ( + ChromaDBVectorMemory, + PersistentChromaDBVectorMemoryConfig, + SentenceTransformerEmbeddingFunctionConfig, + OpenAIEmbeddingFunctionConfig, + ) + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + def get_weather(city: str) -> str: + return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F." + + + def fahrenheit_to_celsius(fahrenheit: float) -> float: + return (fahrenheit - 32) * 5.0 / 9.0 + + + async def main() -> None: + # Use default embedding function + default_memory = ChromaDBVectorMemory( + config=PersistentChromaDBVectorMemoryConfig( + collection_name="user_preferences", + persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"), + k=3, # Return top 3 results + score_threshold=0.5, # Minimum similarity score + ) + ) + + # Using a custom SentenceTransformer model + custom_memory = ChromaDBVectorMemory( + config=PersistentChromaDBVectorMemoryConfig( + collection_name="multilingual_memory", + persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"), + embedding_function_config=SentenceTransformerEmbeddingFunctionConfig( + model_name="paraphrase-multilingual-mpnet-base-v2" + ), + ) + ) + + # Using OpenAI embeddings + openai_memory = ChromaDBVectorMemory( + config=PersistentChromaDBVectorMemoryConfig( + collection_name="openai_memory", + persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"), + embedding_function_config=OpenAIEmbeddingFunctionConfig( + api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small" + ), + ) + ) + + # Add user preferences to memory + await openai_memory.add( + MemoryContent( + content="The user prefers weather temperatures in Celsius", + mime_type=MemoryMimeType.TEXT, + metadata={"category": "preferences", "type": "units"}, + ) + ) + + # Create assistant agent with ChromaDB memory + assistant = AssistantAgent( + name="assistant", + model_client=OpenAIChatCompletionClient( + model="gpt-4.1", + ), + tools=[ + get_weather, + fahrenheit_to_celsius, + ], + max_tool_iterations=10, + memory=[openai_memory], + ) + + # The memory will automatically retrieve relevant content during conversations + await Console(assistant.run_stream(task="What's the temperature in New York?")) + + # Remember to close the memory when finished + await default_memory.close() + await custom_memory.close() + await openai_memory.close() + + + asyncio.run(main()) + + Output: + + .. code-block:: text + + ---------- TextMessage (user) ---------- + What's the temperature in New York? + ---------- MemoryQueryEvent (assistant) ---------- + [MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})] + ---------- ToolCallRequestEvent (assistant) ---------- + [FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')] + ---------- ToolCallExecutionEvent (assistant) ---------- + [FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)] + ---------- ToolCallRequestEvent (assistant) ---------- + [FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')] + ---------- ToolCallExecutionEvent (assistant) ---------- + [FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)] + ---------- TextMessage (assistant) ---------- + The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C. + + """ + + component_config_schema = ChromaDBVectorMemoryConfig + component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory" + + def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None: + self._config = config or PersistentChromaDBVectorMemoryConfig() + self._client: ClientAPI | None = None + self._collection: Collection | None = None + + @property + def collection_name(self) -> str: + """Get the name of the ChromaDB collection.""" + return self._config.collection_name + + def _create_embedding_function(self) -> Any: + """Create an embedding function based on the configuration. + + Returns: + A ChromaDB-compatible embedding function. + + Raises: + ValueError: If the embedding function type is unsupported. + ImportError: If required dependencies are not installed. + """ + try: + from chromadb.utils import embedding_functions + except ImportError as e: + raise ImportError( + "ChromaDB embedding functions not available. Ensure chromadb is properly installed." + ) from e + + config = self._config.embedding_function_config + + if isinstance(config, DefaultEmbeddingFunctionConfig): + return embedding_functions.DefaultEmbeddingFunction() + + elif isinstance(config, SentenceTransformerEmbeddingFunctionConfig): + try: + return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name) + except Exception as e: + raise ImportError( + f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. " + f"Ensure sentence-transformers is installed and the model is available. Error: {e}" + ) from e + + elif isinstance(config, OpenAIEmbeddingFunctionConfig): + try: + return embedding_functions.OpenAIEmbeddingFunction(api_key=config.api_key, model_name=config.model_name) + except Exception as e: + raise ImportError( + f"Failed to create OpenAI embedding function with model '{config.model_name}'. " + f"Ensure openai is installed and API key is valid. Error: {e}" + ) from e + + elif isinstance(config, CustomEmbeddingFunctionConfig): + try: + return config.function(**config.params) + except Exception as e: + raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e + + else: + raise ValueError(f"Unsupported embedding function config type: {type(config)}") + + def _ensure_initialized(self) -> None: + """Ensure ChromaDB client and collection are initialized.""" + if self._client is None: + try: + from chromadb.config import Settings + + settings = Settings(allow_reset=self._config.allow_reset) + + if isinstance(self._config, PersistentChromaDBVectorMemoryConfig): + self._client = PersistentClient( + path=self._config.persistence_path, + settings=settings, + tenant=self._config.tenant, + database=self._config.database, + ) + elif isinstance(self._config, HttpChromaDBVectorMemoryConfig): + self._client = HttpClient( + host=self._config.host, + port=self._config.port, + ssl=self._config.ssl, + headers=self._config.headers, + settings=settings, + tenant=self._config.tenant, + database=self._config.database, + ) + else: + raise ValueError(f"Unsupported config type: {type(self._config)}") + except Exception as e: + logger.error(f"Failed to initialize ChromaDB client: {e}") + raise + + if self._collection is None: + try: + # Create embedding function + embedding_function = self._create_embedding_function() + + # Create or get collection with embedding function + self._collection = self._client.get_or_create_collection( + name=self._config.collection_name, + metadata={"distance_metric": self._config.distance_metric}, + embedding_function=embedding_function, + ) + except Exception as e: + logger.error(f"Failed to get/create collection: {e}") + raise + + def _extract_text(self, content_item: str | MemoryContent) -> str: + """Extract searchable text from content.""" + if isinstance(content_item, str): + return content_item + + content = content_item.content + mime_type = content_item.mime_type + + if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: + return str(content) + elif mime_type == MemoryMimeType.JSON: + if isinstance(content, dict): + # Store original JSON string representation + return str(content).lower() + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError(f"Unsupported content type: {mime_type}") + + def _calculate_score(self, distance: float) -> float: + """Convert ChromaDB distance to a similarity score.""" + if self._config.distance_metric == "cosine": + return 1.0 - (distance / 2.0) + return 1.0 / (1.0 + distance) + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + messages = await model_context.get_messages() + if not messages: + return UpdateContextResult(memories=MemoryQueryResult(results=[])) + + # Extract query from last message + last_message = messages[-1] + query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) + + # Query memory and get results + query_results = await self.query(query_text) + + if query_results.results: + # Format results for context + memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)] + memory_context = "\nRelevant memory content:\n" + "\n".join(memory_strings) + + # Add to context + await model_context.add_message(SystemMessage(content=memory_context)) + + return UpdateContextResult(memories=query_results) + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + # Extract text from content + text = self._extract_text(content) + + # Use metadata directly from content + metadata_dict = content.metadata or {} + metadata_dict["mime_type"] = str(content.mime_type) + + # Add to ChromaDB + self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())]) + + except Exception as e: + logger.error(f"Failed to add content to ChromaDB: {e}") + raise + + async def query( + self, + query: str | MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> MemoryQueryResult: + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + # Extract text for query + query_text = self._extract_text(query) + + # Query ChromaDB + results = self._collection.query( + query_texts=[query_text], + n_results=self._config.k, + include=["documents", "metadatas", "distances"], + **kwargs, + ) + + # Convert results to MemoryContent list + memory_results: List[MemoryContent] = [] + + if ( + not results + or not results.get("documents") + or not results.get("metadatas") + or not results.get("distances") + ): + return MemoryQueryResult(results=memory_results) + + documents: List[Document] = results["documents"][0] if results["documents"] else [] + metadatas: List[Metadata] = results["metadatas"][0] if results["metadatas"] else [] + distances: List[float] = results["distances"][0] if results["distances"] else [] + ids: List[str] = results["ids"][0] if results["ids"] else [] + + for doc, metadata_dict, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False): + # Calculate score + score = self._calculate_score(distance) + metadata = dict(metadata_dict) + metadata["score"] = score + metadata["id"] = doc_id + if self._config.score_threshold is not None and score < self._config.score_threshold: + continue + + # Extract mime_type from metadata + mime_type = str(metadata_dict.get("mime_type", MemoryMimeType.TEXT.value)) + + # Create MemoryContent + content = MemoryContent( + content=doc, + mime_type=mime_type, + metadata=metadata, + ) + memory_results.append(content) + + return MemoryQueryResult(results=memory_results) + + except Exception as e: + logger.error(f"Failed to query ChromaDB: {e}") + raise + + async def clear(self) -> None: + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + results = self._collection.get() + if results and results["ids"]: + self._collection.delete(ids=results["ids"]) + except Exception as e: + logger.error(f"Failed to clear ChromaDB collection: {e}") + raise + + async def close(self) -> None: + """Clean up ChromaDB client and resources.""" + self._collection = None + self._client = None + + async def reset(self) -> None: + self._ensure_initialized() + if not self._config.allow_reset: + raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.") + + if self._client is not None: + try: + self._client.reset() + except Exception as e: + logger.error(f"Error during ChromaDB reset: {e}") + finally: + self._collection = None + + def _to_config(self) -> ChromaDBVectorMemoryConfig: + """Serialize the memory configuration.""" + + return self._config + + @classmethod + def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self: + """Deserialize the memory configuration.""" + + return cls(config=config) diff --git a/agent_dhal/agentdhal_extensions/memory/mem0/__init__.py b/agent_dhal/agentdhal_extensions/memory/mem0/__init__.py new file mode 100644 index 0000000..2f1af25 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/mem0/__init__.py @@ -0,0 +1,6 @@ +from ._mem0 import Mem0Memory, Mem0MemoryConfig + +__all__ = [ + "Mem0Memory", + "Mem0MemoryConfig", +] diff --git a/agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py b/agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py new file mode 100644 index 0000000..5649ef3 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py @@ -0,0 +1,449 @@ +import io +import logging +import uuid +from contextlib import redirect_stderr, redirect_stdout +from datetime import datetime +from typing import Any, Dict, List, Optional, TypedDict, cast + +from agentdhal_core import CancellationToken, Component, ComponentBase +from agentdhal_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult +from agentdhal_core.model_context import ChatCompletionContext +from agentdhal_core.models import SystemMessage +from mem0 import Memory as Memory0 +from mem0 import MemoryClient +from pydantic import BaseModel, Field +from typing_extensions import Self + +logger = logging.getLogger(__name__) +logging.getLogger("chromadb").setLevel(logging.ERROR) + + +class Mem0MemoryConfig(BaseModel): + """Configuration for Mem0Memory component.""" + + user_id: Optional[str] = Field( + default=None, description="User ID for memory operations. If not provided, a UUID will be generated." + ) + limit: int = Field(default=10, description="Maximum number of results to return in memory queries.") + is_cloud: bool = Field(default=True, description="Whether to use cloud Mem0 client (True) or local client (False).") + api_key: Optional[str] = Field( + default=None, description="API key for cloud Mem0 client. Required if is_cloud=True." + ) + config: Optional[Dict[str, Any]] = Field( + default=None, description="Configuration dictionary for local Mem0 client. Required if is_cloud=False." + ) + + +class MemoryResult(TypedDict, total=False): + memory: str + score: float + metadata: Dict[str, Any] + created_at: str + updated_at: str + categories: List[str] + + +# pyright: reportGeneralTypeIssues=false +class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]): + """Mem0 memory implementation for AutoGen. + + This component integrates with Mem0.ai's memory system, providing an implementation + of AutoGen's Memory interface. It supports both cloud and local backends through the + mem0ai Python package. + + To use this component, you need to have the `mem0` (for cloud-only) or `mem0-local` (for local) + extra installed for the `autogen-ext` package: + + .. code-block:: bash + + pip install -U "agentdhal-ext[mem0]" # For cloud-based Mem0 + pip install -U "agentdhal-ext[mem0-local]" # For local Mem0 + + The memory component can store and retrieve information that agents need to remember + across conversations. It also provides context updating for language models with + relevant memories. + + Examples: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.memory.mem0 import Mem0Memory + from agentdhal_core.memory import MemoryContent + + + async def main() -> None: + # Create a local Mem0Memory (no API key required) + memory = Mem0Memory( + is_cloud=False, + config={"path": ":memory:"}, # Use in-memory storage for testing + ) + print("Memory initialized successfully!") + + # Add something to memory + test_content = "User likes the color blue." + await memory.add(MemoryContent(content=test_content, mime_type="text/plain")) + print(f"Added content: {test_content}") + + # Retrieve memories with a search query + results = await memory.query("What color does the user like?") + print(f"Query results: {len(results.results)} found") + + for i, result in enumerate(results.results): + print(f"Result {i+1}: {result}") + + + asyncio.run(main()) + + Output: + + .. code-block:: text + + Memory initialized successfully! + Added content: User likes the color blue. + Query results: 1 found + Result 1: content='User likes the color blue' mime_type='text/plain' metadata={'score': 0.6977155806281953, 'created_at': datetime.datetime(2025, 7, 6, 17, 25, 18, 754725, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200)))} + + Using it with an :class:`~agentdhal_agentchat.agents.AssistantAgent`: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_core.memory import MemoryContent + from agentdhal_extensions.memory.mem0 import Mem0Memory + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + + async def main() -> None: + # Create a model client + model_client = OpenAIChatCompletionClient(model="gpt-4.1") + + # Create a Mem0 memory instance + memory = Mem0Memory( + user_id="user123", + is_cloud=False, + config={"path": ":memory:"}, # Use in-memory storage for testing + ) + + # Add something to memory + test_content = "User likes the color blue." + await memory.add(MemoryContent(content=test_content, mime_type="text/plain")) + + # Create an assistant agent with Mem0 memory + agent = AssistantAgent( + name="assistant", + model_client=model_client, + memory=[memory], + system_message="You are a helpful assistant that remembers user preferences.", + ) + + # Run a sample task + result = await agent.run(task="What color does the user like?") + print(result.messages[-1].content) # type: ignore + + + asyncio.run(main()) + + Output: + + .. code-block:: text + + User likes the color blue. + + Args: + user_id: Optional user ID for memory operations. If not provided, a UUID will be generated. + limit: Maximum number of results to return in memory queries. + is_cloud: Whether to use cloud Mem0 client (True) or local client (False). + api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided. + config: Configuration dictionary for local Mem0 client. Required if is_cloud=False. + """ + + component_type = "memory" + component_provider_override = "agentdhal_extensions.memory.mem0.Mem0Memory" + component_config_schema = Mem0MemoryConfig + + def __init__( + self, + user_id: Optional[str] = None, + limit: int = 10, + is_cloud: bool = True, + api_key: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + # Validate parameters + if not is_cloud and config is None: + raise ValueError("config is required when using local Mem0 client (is_cloud=False)") + + # Initialize instance variables + self._user_id = user_id or str(uuid.uuid4()) + self._limit = limit + self._is_cloud = is_cloud + self._api_key = api_key + self._config = config + + # Initialize client + if self._is_cloud: + self._client = MemoryClient(api_key=self._api_key) + else: + assert self._config is not None + config_dict = self._config + self._client = Memory0.from_config(config_dict=config_dict) # type: ignore + + @property + def user_id(self) -> str: + """Get the user ID for memory operations.""" + return self._user_id + + @property + def limit(self) -> int: + """Get the maximum number of results to return in memory queries.""" + return self._limit + + @property + def is_cloud(self) -> bool: + """Check if the Mem0 client is cloud-based.""" + return self._is_cloud + + @property + def config(self) -> Optional[Dict[str, Any]]: + """Get the configuration for the Mem0 client.""" + return self._config + + async def add( + self, + content: MemoryContent, + cancellation_token: Optional[CancellationToken] = None, + ) -> None: + """Add content to memory. + + Args: + content: The memory content to add. + cancellation_token: Optional token to cancel operation. + + Raises: + Exception: If there's an error adding content to mem0 memory. + """ + # Extract content based on mime type + if hasattr(content, "content") and hasattr(content, "mime_type"): + if content.mime_type in ["text/plain", "text/markdown"]: + message = str(content.content) + elif content.mime_type == "application/json": + # Convert JSON content to string representation + if isinstance(content.content, str): + message = content.content + else: + # Convert dict or other JSON serializable objects to string + import json + + message = json.dumps(content.content) + else: + message = str(content.content) + + # Extract metadata + metadata = content.metadata or {} + else: + # Handle case where content is directly provided as string + message = str(content) + metadata = {} + + # Check if operation is cancelled + if cancellation_token is not None and cancellation_token.cancelled: # type: ignore + return + + # Add to mem0 client + try: + user_id = metadata.pop("user_id", self._user_id) + # Suppress warning messages from mem0 MemoryClient + kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"} + with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): + self._client.add([{"role": "user", "content": message}], user_id=user_id, metadata=metadata, **kwargs) # type: ignore + except Exception as e: + # Log the error but don't crash + logger.error(f"Error adding to mem0 memory: {str(e)}") + raise + + async def query( + self, + query: str | MemoryContent = "", + cancellation_token: Optional[CancellationToken] = None, + **kwargs: Any, + ) -> MemoryQueryResult: + """Query memory for relevant content. + + Args: + query: The query to search for, either as string or MemoryContent. + cancellation_token: Optional token to cancel operation. + **kwargs: Additional query parameters to pass to mem0. + + Returns: + MemoryQueryResult containing search results. + """ + # Extract query text + if isinstance(query, str): + query_text = query + elif hasattr(query, "content"): + query_text = str(query.content) + else: + query_text = str(query) + + # Check if operation is cancelled + if ( + cancellation_token + and hasattr(cancellation_token, "cancelled") + and getattr(cancellation_token, "cancelled", False) + ): + return MemoryQueryResult(results=[]) + + try: + limit = kwargs.pop("limit", self._limit) + with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): + # Query mem0 client + results = self._client.search( # type: ignore + query_text, + user_id=self._user_id, + limit=limit, + **kwargs, + ) + + # Type-safe handling of results + if isinstance(results, dict) and "results" in results: + result_list = cast(List[MemoryResult], results["results"]) + else: + result_list = cast(List[MemoryResult], results) + + # Convert results to MemoryContent objects + memory_contents: List[MemoryContent] = [] + for result in result_list: + content_text = result.get("memory", "") + metadata: Dict[str, Any] = {} + + if "metadata" in result and result["metadata"]: + metadata = result["metadata"] + + # Add relevant fields to metadata + if "score" in result: + metadata["score"] = result["score"] + + # For created_at + if "created_at" in result and result.get("created_at"): + try: + metadata["created_at"] = datetime.fromisoformat(result["created_at"]) + except (ValueError, TypeError): + pass + + # For updated_at + if "updated_at" in result and result.get("updated_at"): + try: + metadata["updated_at"] = datetime.fromisoformat(result["updated_at"]) + except (ValueError, TypeError): + pass + + # For categories + if "categories" in result and result.get("categories"): + metadata["categories"] = result["categories"] + + # Create MemoryContent object + memory_content = MemoryContent( + content=content_text, + mime_type="text/plain", # Default to text/plain + metadata=metadata, + ) + memory_contents.append(memory_content) + + return MemoryQueryResult(results=memory_contents) + + except Exception as e: + # Log the error but return empty results + logger.error(f"Error querying mem0 memory: {str(e)}") + return MemoryQueryResult(results=[]) + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + """Update the model context with relevant memories. + + This method retrieves the conversation history from the model context, + uses the last message as a query to find relevant memories, and then + adds those memories to the context as a system message. + + Args: + model_context: The model context to update. + + Returns: + UpdateContextResult containing memories added to the context. + """ + # Get messages from context + messages = await model_context.get_messages() + if not messages: + return UpdateContextResult(memories=MemoryQueryResult(results=[])) + + # Use the last message as query + last_message = messages[-1] + query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) + + # Query memory + query_results = await self.query(query_text, limit=self._limit) + + # If we have results, add them to the context + if query_results.results: + # Format memories as numbered list + memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)] + memory_context = "\nRelevant memories:\n" + "\n".join(memory_strings) + + # Add as system message + await model_context.add_message(SystemMessage(content=memory_context)) + + return UpdateContextResult(memories=query_results) + + async def clear(self) -> None: + """Clear all content from memory for the current user. + + Raises: + Exception: If there's an error clearing mem0 memory. + """ + try: + self._client.delete_all(user_id=self._user_id) # type: ignore + except Exception as e: + logger.error(f"Error clearing mem0 memory: {str(e)}") + raise + + async def close(self) -> None: + """Clean up resources if needed. + + This is a no-op for Mem0 clients as they don't require explicit cleanup. + """ + pass + + @classmethod + def _from_config(cls, config: Mem0MemoryConfig) -> Self: + """Create instance from configuration. + + Args: + config: Configuration for Mem0Memory component. + + Returns: + A new Mem0Memory instance. + """ + return cls( + user_id=config.user_id, + limit=config.limit, + is_cloud=config.is_cloud, + api_key=config.api_key, + config=config.config, + ) + + def _to_config(self) -> Mem0MemoryConfig: + """Convert instance to configuration. + + Returns: + Configuration representing this Mem0Memory instance. + """ + return Mem0MemoryConfig( + user_id=self._user_id, + limit=self._limit, + is_cloud=self._is_cloud, + api_key=self._api_key, + config=self._config, + ) diff --git a/agent_dhal/agentdhal_extensions/memory/redis/__init__.py b/agent_dhal/agentdhal_extensions/memory/redis/__init__.py new file mode 100644 index 0000000..606cf2d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/redis/__init__.py @@ -0,0 +1,9 @@ +from ._redis_memory import ( + RedisMemory, + RedisMemoryConfig, +) + +__all__ = [ + "RedisMemoryConfig", + "RedisMemory", +] diff --git a/agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py b/agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py new file mode 100644 index 0000000..6545436 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py @@ -0,0 +1,325 @@ +import logging +from typing import Any, List, Literal + +from agentdhal_core import CancellationToken, Component +from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult +from agentdhal_core.model_context import ChatCompletionContext +from agentdhal_core.models import SystemMessage +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +try: + from redis import Redis + from redisvl.extensions.message_history import SemanticMessageHistory + from redisvl.utils.utils import deserialize, serialize +except ImportError as e: + raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e + + +class RedisMemoryConfig(BaseModel): + """ + Configuration for Redis-based vector memory. + + This class defines the configuration options for using Redis as a vector memory store, + supporting semantic memory. It allows customization of the Redis connection, index settings, + similarity search parameters, and embedding model. + """ + + redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance") + index_name: str = Field(default="chat_history", description="Name of the Redis collection") + prefix: str = Field(default="memory", description="prefix of the Redis collection") + distance_metric: Literal["cosine", "ip", "l2"] = "cosine" + algorithm: Literal["flat", "hnsw"] = "flat" + top_k: int = Field(default=10, description="Number of results to return in queries") + datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32" + distance_threshold: float = Field(default=0.7, description="Minimum similarity score threshold") + model_name: str | None = Field( + default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name" + ) + + +class RedisMemory(Memory, Component[RedisMemoryConfig]): + """ + Store and retrieve memory using vector similarity search powered by RedisVL. + + `RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and + retrieving content based on semantic similarity. It enhances agents with the ability to recall + contextually relevant information during conversations by leveraging vector embeddings to find + similar content. + + This implementation requires the RedisVL extra to be installed. Install with: + + .. code-block:: bash + + pip install "agentdhal-ext[redisvl]" + + Additionally, you will need access to a Redis instance. + To run a local instance of redis in docker: + + .. code-block:: bash + + docker run -d --name redis -p 6379:6379 redis:8 + + To download and run Redis locally: + + .. code-block:: bash + + curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg + echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list + sudo apt-get update > /dev/null 2>&1 + sudo apt-get install redis-server > /dev/null 2>&1 + redis-server --daemonize yes + + Args: + config (RedisMemoryConfig | None): Configuration for the Redis memory. + If None, defaults to a RedisMemoryConfig with recommended settings. + + Example: + + .. code-block:: python + + from logging import WARNING, getLogger + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core.memory import MemoryContent, MemoryMimeType + from agentdhal_extensions.memory.redis import RedisMemory, RedisMemoryConfig + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + + logger = getLogger() + logger.setLevel(WARNING) + + + # Define tool to use + async def get_weather(city: str, units: str = "imperial") -> str: + if units == "imperial": + return f"The weather in {city} is 73 °F and Sunny." + elif units == "metric": + return f"The weather in {city} is 23 °C and Sunny." + else: + return f"Sorry, I don't know the weather in {city}." + + + async def main(): + # Initailize Redis memory + redis_memory = RedisMemory( + config=RedisMemoryConfig( + redis_url="redis://localhost:6379", + index_name="chat_history", + prefix="memory", + ) + ) + + # Add user preferences to memory + await redis_memory.add( + MemoryContent( + content="The weather should be in metric units", + mime_type=MemoryMimeType.TEXT, + metadata={"category": "preferences", "type": "units"}, + ) + ) + + await redis_memory.add( + MemoryContent( + content="Meal recipe must be vegan", + mime_type=MemoryMimeType.TEXT, + metadata={"category": "preferences", "type": "dietary"}, + ) + ) + + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + ) + + # Create assistant agent with ChromaDB memory + assistant_agent = AssistantAgent( + name="assistant_agent", + model_client=model_client, + tools=[get_weather], + memory=[redis_memory], + ) + + stream = assistant_agent.run_stream(task="What is the weather in New York?") + await Console(stream) + + await model_client.close() + await redis_memory.close() + + + asyncio.run(main()) + + Output: + + .. code-block:: text + + ---------- TextMessage (user) ---------- + What is the weather in New York? + ---------- MemoryQueryEvent (assistant_agent) ---------- + [MemoryContent(content='The weather should be in metric units', mime_type=, metadata={'category': 'preferences', 'type': 'units'})] + ---------- ToolCallRequestEvent (assistant_agent) ---------- + [FunctionCall(id='call_tyCPvPPAV4SHWhtfpM6UMemr', arguments='{"city":"New York","units":"metric"}', name='get_weather')] + ---------- ToolCallExecutionEvent (assistant_agent) ---------- + [FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_tyCPvPPAV4SHWhtfpM6UMemr', is_error=False)] + ---------- ToolCallSummaryMessage (assistant_agent) ---------- + The weather in New York is 23 °C and Sunny. + + """ + + component_config_schema = RedisMemoryConfig + component_provider_override = "agentdhal_extensions.memory.redis_memory.RedisMemory" + + def __init__(self, config: RedisMemoryConfig | None = None) -> None: + """Initialize RedisMemory.""" + self.config = config or RedisMemoryConfig() + client = Redis.from_url(url=self.config.redis_url) # type: ignore[reportUknownMemberType] + + self.message_history = SemanticMessageHistory(name=self.config.index_name, redis_client=client) + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> UpdateContextResult: + """ + Update the model context with relevant memory content. + + This method retrieves memory content relevant to the last message in the context + and adds it as a system message. This implementation uses the last message in the context + as a query to find semantically similar memories and adds them all to the context as a + single system message. + + Args: + model_context (ChatCompletionContext): The model context to update with relevant + memories. + + Returns: + UpdateContextResult: Object containing the memories that were used to update the + context. + """ + messages = await model_context.get_messages() + if messages: + last_message = str(messages[-1].content) + else: + last_message = "" + + query_results = await self.query(last_message) + + stringified_messages = "\n\n".join([str(m.content) for m in query_results.results]) + + await model_context.add_message(SystemMessage(content=stringified_messages)) + + return UpdateContextResult(memories=query_results) + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + """Add a memory content object to Redis. + + .. note:: + + To perform semantic search over stored memories RedisMemory creates a vector embedding + from the content field of a MemoryContent object. This content is assumed to be text, + JSON, or Markdown, and is passed to the vector embedding model specified in + RedisMemoryConfig. + + Args: + content (MemoryContent): The memory content to store within Redis. + cancellation_token (CancellationToken): Token passed to cease operation. Not used. + """ + if content.mime_type == MemoryMimeType.TEXT: + memory_content = content.content + mime_type = "text/plain" + elif content.mime_type == MemoryMimeType.JSON: + memory_content = serialize(content.content) + mime_type = "application/json" + elif content.mime_type == MemoryMimeType.MARKDOWN: + memory_content = content.content + mime_type = "text/markdown" + else: + raise NotImplementedError( + f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported." + ) + metadata = {"mime_type": mime_type} + metadata.update(content.metadata if content.metadata else {}) + self.message_history.add_message( + {"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType] + ) + + async def query( + self, + query: str | MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> MemoryQueryResult: + """Query memory content based on semantic vector similarity. + + .. note:: + + RedisMemory.query() supports additional keyword arguments to improve query performance. + top_k (int): The maximum number of relevant memories to include. Defaults to 10. + distance_threshold (float): The maximum distance in vector space to consider a memory + semantically similar when performining cosine similarity search. Defaults to 0.7. + + Args: + query (str | MemoryContent): query to perform vector similarity search with. If a + string is passed, a vector embedding is created from it with the model specified + in the RedisMemoryConfig. If a MemoryContent object is passed, the content field + of this object is extracted and a vector embedding is created from it with the + model specified in the RedisMemoryConfig. + cancellation_token (CancellationToken): Token passed to cease operation. Not used. + + Returns: + memoryQueryResult: Object containing memories relevant to the provided query. + """ + # get the query string, or raise an error for unsupported MemoryContent types + if isinstance(query, str): + prompt = query + elif isinstance(query, MemoryContent): + if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN): + prompt = str(query.content) + elif query.mime_type == MemoryMimeType.JSON: + prompt = serialize(query.content) + else: + raise NotImplementedError( + f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported." + ) + else: + raise TypeError("'query' must be either a string or MemoryContent") + + top_k = kwargs.pop("top_k", self.config.top_k) + distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold) + + results = self.message_history.get_relevant( + prompt=prompt, # type: ignore[reportArgumentType] + top_k=top_k, + distance_threshold=distance_threshold, + raw=False, + ) + + memories: List[MemoryContent] = [] + for result in results: + metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType] + mime_type = MemoryMimeType(metadata.pop("mime_type")) + if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN): + memory_content = result["content"] # type: ignore[reportArgumentType] + elif mime_type == MemoryMimeType.JSON: + memory_content = deserialize(result["content"]) # type: ignore[reportArgumentType] + else: + raise NotImplementedError( + f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported." + ) + memory = MemoryContent( + content=memory_content, # type: ignore[reportArgumentType] + mime_type=mime_type, + metadata=metadata, + ) + memories.append(memory) # type: ignore[reportUknownMemberType] + + return MemoryQueryResult(results=memories) # type: ignore[reportUknownMemberType] + + async def clear(self) -> None: + """Clear all entries from memory, preserving the RedisMemory resources.""" + self.message_history.clear() + + async def close(self) -> None: + """Clears all entries from memory, and cleans up Redis client, index and resources.""" + self.message_history.delete() diff --git a/agent_dhal/agentdhal_extensions/py.typed b/agent_dhal/agentdhal_extensions/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_extensions/runtimes/__init__.py b/agent_dhal/agentdhal_extensions/runtimes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/__init__.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/__init__.py new file mode 100644 index 0000000..dacfa6b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/__init__.py @@ -0,0 +1,16 @@ +from ._worker_runtime import GrpcWorkerAgentRuntime +from ._worker_runtime_host import GrpcWorkerAgentRuntimeHost +from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer + +try: + import grpc # type: ignore +except ImportError as e: + raise ImportError( + "To use the GRPC runtime the grpc extra must be installed. Run `pip install autogen-ext[grpc]`" + ) from e + +__all__ = [ + "GrpcWorkerAgentRuntime", + "GrpcWorkerAgentRuntimeHost", + "GrpcWorkerAgentRuntimeHostServicer", +] diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_constants.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_constants.py new file mode 100644 index 0000000..6dab3ff --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_constants.py @@ -0,0 +1,13 @@ +GRPC_IMPORT_ERROR_STR = ( + "Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]" +) + +DATA_CONTENT_TYPE_ATTR = "datacontenttype" +DATA_SCHEMA_ATTR = "dataschema" +AGENT_SENDER_TYPE_ATTR = "agagentsendertype" +AGENT_SENDER_KEY_ATTR = "agagentsenderkey" +MESSAGE_KIND_ATTR = "agmsgkind" +MESSAGE_KIND_VALUE_PUBLISH = "publish" +MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request" +MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response" +MESSAGE_KIND_VALUE_RPC_ERROR = "error" diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_type_helpers.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_type_helpers.py new file mode 100644 index 0000000..be24207 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_type_helpers.py @@ -0,0 +1,4 @@ +from typing import Any, Sequence, Tuple + +# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors +ChannelArgumentType = Sequence[Tuple[str, Any]] diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_utils.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_utils.py new file mode 100644 index 0000000..2179282 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_utils.py @@ -0,0 +1,45 @@ +from agentdhal_core._subscription import Subscription +from agentdhal_core._type_prefix_subscription import TypePrefixSubscription +from agentdhal_core._type_subscription import TypeSubscription + +from .protos import agent_worker_pb2 + + +def subscription_to_proto(subscription: Subscription) -> agent_worker_pb2.Subscription: + match subscription: + case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id): + return agent_worker_pb2.Subscription( + id=id, + typeSubscription=agent_worker_pb2.TypeSubscription(topic_type=topic_type, agent_type=agent_type), + ) + case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id): + return agent_worker_pb2.Subscription( + id=id, + typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription( + topic_type_prefix=topic_type_prefix, agent_type=agent_type + ), + ) + case _: + raise ValueError("Unsupported subscription type.") + + +def subscription_from_proto(subscription: agent_worker_pb2.Subscription) -> Subscription: + oneofcase = subscription.WhichOneof("subscription") + match oneofcase: + case "typeSubscription": + type_subscription_msg: agent_worker_pb2.TypeSubscription = subscription.typeSubscription + return TypeSubscription( + topic_type=type_subscription_msg.topic_type, + agent_type=type_subscription_msg.agent_type, + id=subscription.id, + ) + + case "typePrefixSubscription": + type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = subscription.typePrefixSubscription + return TypePrefixSubscription( + topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix, + agent_type=type_prefix_subscription_msg.agent_type, + id=subscription.id, + ) + case None: + raise ValueError("Invalid subscription message.") diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime.py new file mode 100644 index 0000000..730f20b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime.py @@ -0,0 +1,856 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +import signal +import uuid +import warnings +from asyncio import Future, Task +from collections import defaultdict +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + ClassVar, + DefaultDict, + Dict, + List, + Literal, + Mapping, + ParamSpec, + Sequence, + Set, + Tuple, + Type, + TypeVar, + cast, +) + +from agentdhal_core import ( + JSON_DATA_CONTENT_TYPE, + PROTOBUF_DATA_CONTENT_TYPE, + Agent, + AgentId, + AgentInstantiationContext, + AgentMetadata, + AgentRuntime, + AgentType, + CancellationToken, + MessageContext, + MessageHandlerContext, + MessageSerializer, + Subscription, + TopicId, +) +from agentdhal_core._runtime_impl_helpers import SubscriptionManager, get_impl +from agentdhal_core._serialization import ( + SerializationRegistry, +) +from agentdhal_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata +from google.protobuf import any_pb2 +from opentelemetry.trace import TracerProvider +from typing_extensions import Self + +from agentdhal_extensions.runtimes.grpc._utils import subscription_to_proto + +from . import _constants +from ._constants import GRPC_IMPORT_ERROR_STR +from ._type_helpers import ChannelArgumentType +from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2 + +try: + import grpc.aio +except ImportError as e: + raise ImportError(GRPC_IMPORT_ERROR_STR) from e + +if TYPE_CHECKING: + from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub + +logger = logging.getLogger("agentdhal_core") +event_logger = logging.getLogger("agentdhal_core.events") + +P = ParamSpec("P") +T = TypeVar("T", bound=Agent) + + +type_func_alias = type + + +class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]): + def __init__(self, queue: asyncio.Queue[Any]) -> None: + self._queue = queue + + async def __anext__(self) -> Any: + return await self._queue.get() + + def __aiter__(self) -> AsyncIterator[Any]: + return self + + +class HostConnection: + DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [ + ( + "grpc.service_config", + json.dumps( + { + "methodConfig": [ + { + "name": [{}], + "retryPolicy": { + "maxAttempts": 3, + "initialBackoff": "0.01s", + "maxBackoff": "5s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + } + ], + } + ), + ) + ] + + def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None: # type: ignore + self._channel = channel + self._send_queue = asyncio.Queue[agent_worker_pb2.Message]() + self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]() + self._connection_task: Task[None] | None = None + self._stub: AgentRpcAsyncStub = stub + self._client_id = str(uuid.uuid4()) + + @property + def stub(self) -> Any: + return self._stub + + @property + def metadata(self) -> Sequence[Tuple[str, str]]: + return [("client-id", self._client_id)] + + @classmethod + async def from_host_address( + cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG + ) -> Self: + logger.info("Connecting to %s", host_address) + # Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config + merged_options = [ + (k, v) for k, v in {**dict(HostConnection.DEFAULT_GRPC_CONFIG), **dict(extra_grpc_config)}.items() + ] + + channel = grpc.aio.insecure_channel( + host_address, + options=merged_options, + ) + stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore + instance = cls(channel, stub) + + instance._connection_task = await instance._connect( + stub, instance._send_queue, instance._recv_queue, instance._client_id + ) + + return instance + + async def close(self) -> None: + if self._connection_task is None: + raise RuntimeError("Connection is not open.") + await self._channel.close() + await self._connection_task + + @staticmethod + async def _connect( + stub: Any, # AgentRpcAsyncStub + send_queue: asyncio.Queue[agent_worker_pb2.Message], + receive_queue: asyncio.Queue[agent_worker_pb2.Message], + client_id: str, + ) -> Task[None]: + from grpc.aio import StreamStreamCall + + # TODO: where do exceptions from reading the iterable go? How do we recover from those? + stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore + QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)] + ) + + await stream.wait_for_connection() + + async def read_loop() -> None: + while True: + logger.info("Waiting for message from host") + message = cast(agent_worker_pb2.Message, await stream.read()) # type: ignore + if message == grpc.aio.EOF: # type: ignore + logger.info("EOF") + break + logger.info(f"Received a message from host: {message}") + await receive_queue.put(message) + logger.info("Put message in receive queue") + + return asyncio.create_task(read_loop()) + + async def send(self, message: agent_worker_pb2.Message) -> None: + logger.info(f"Send message to host: {message}") + await self._send_queue.put(message) + logger.info("Put message in send queue") + + async def recv(self) -> agent_worker_pb2.Message: + logger.info("Getting message from queue") + return await self._recv_queue.get() + + +# TODO: Lots of types need to have protobuf equivalents: +# Core: +# - FunctionCall, CodeResult, possibly CodeBlock +# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/agentdhal_core/models/_types.py +# +# Agentchat: +# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/src/agentdhal_agentchat/messages.py to protobufs. +# +# Ext -- +# CodeExecutor: +# - CommandLineCodeResult + + +class GrpcWorkerAgentRuntime(AgentRuntime): + """An agent runtime for running remote or cross-language agents. + + Agent messaging uses protobufs from `agent_worker.proto`_ and ``CloudEvent`` from `cloudevent.proto`_. + + Cross-language agents will additionally require all agents use shared protobuf schemas for any message types that are sent between agents. + + .. _agent_worker.proto: https://github.com/microsoft/autogen/blob/main/protos/agent_worker.proto + + .. _cloudevent.proto: https://github.com/microsoft/autogen/blob/main/protos/cloudevent.proto + + """ + + # TODO: Needs to handle agent close() call + def __init__( + self, + host_address: str, + tracer_provider: TracerProvider | None = None, + extra_grpc_config: ChannelArgumentType | None = None, + payload_serialization_format: str = JSON_DATA_CONTENT_TYPE, + ) -> None: + self._host_address = host_address + self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime")) + self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) + self._agent_factories: Dict[ + str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] + ] = {} + self._instantiated_agents: Dict[AgentId, Agent] = {} + self._known_namespaces: set[str] = set() + self._read_task: None | Task[None] = None + self._running = False + self._pending_requests: Dict[str, Future[Any]] = {} + self._pending_requests_lock = asyncio.Lock() + self._next_request_id = 0 + self._host_connection: HostConnection | None = None + self._background_tasks: Set[Task[Any]] = set() + self._subscription_manager = SubscriptionManager() + self._serialization_registry = SerializationRegistry() + self._extra_grpc_config = extra_grpc_config or [] + self._agent_instance_types: Dict[str, Type[Agent]] = {} + + if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}: + raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}") + + self._payload_serialization_format = payload_serialization_format + + async def start(self) -> None: + """Start the runtime in a background task.""" + if self._running: + raise ValueError("Runtime is already running.") + logger.info(f"Connecting to host: {self._host_address}") + self._host_connection = await HostConnection.from_host_address( + self._host_address, extra_grpc_config=self._extra_grpc_config + ) + logger.info("Connection established") + if self._read_task is None: + self._read_task = asyncio.create_task(self._run_read_loop()) + self._running = True + + def _raise_on_exception(self, task: Task[Any]) -> None: + exception = task.exception() + if exception is not None: + raise exception + + async def _run_read_loop(self) -> None: + logger.info("Starting read loop") + assert self._host_connection is not None + # TODO: catch exceptions and reconnect + while self._running: + try: + message = await self._host_connection.recv() + oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message") + match oneofcase: + case "request": + task = asyncio.create_task(self._process_request(message.request)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "response": + task = asyncio.create_task(self._process_response(message.response)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "cloudEvent": + task = asyncio.create_task(self._process_event(message.cloudEvent)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case None: + logger.warning("No message") + except Exception as e: + logger.error("Error in read loop", exc_info=e) + + async def stop(self) -> None: + """Stop the runtime immediately.""" + if not self._running: + raise RuntimeError("Runtime is not running.") + self._running = False + # Wait for all background tasks to finish. + final_tasks_results = await asyncio.gather(*self._background_tasks, return_exceptions=True) + for task_result in final_tasks_results: + if isinstance(task_result, Exception): + logger.error("Error in background task", exc_info=task_result) + # Close the host connection. + if self._host_connection is not None: + try: + await self._host_connection.close() + except asyncio.CancelledError: + pass + # Cancel the read task. + if self._read_task is not None: + self._read_task.cancel() + try: + await self._read_task + except asyncio.CancelledError: + pass + + async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None: + """Stop the runtime when a signal is received.""" + loop = asyncio.get_running_loop() + shutdown_event = asyncio.Event() + + def signal_handler() -> None: + logger.info("Received exit signal, shutting down gracefully...") + shutdown_event.set() + + for sig in signals: + loop.add_signal_handler(sig, signal_handler) + + # Wait for the signal to trigger the shutdown event. + await shutdown_event.wait() + + # Stop the runtime. + await self.stop() + + @property + def _known_agent_names(self) -> Set[str]: + return set(self._agent_factories.keys()) + + async def _send_message( + self, + runtime_message: agent_worker_pb2.Message, + send_type: Literal["send", "publish"], + recipient: AgentId | TopicId, + telemetry_metadata: Mapping[str, str], + ) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata): + await self._host_connection.send(runtime_message) + + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> Any: + # TODO: use message_id + if not self._running: + raise ValueError("Runtime must be running when sending message.") + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + data_type = self._serialization_registry.type_name(message) + with self._trace_helper.trace_block( + "create", recipient, parent=None, extraAttributes={"message_type": data_type} + ): + # create a new future for the result + future = asyncio.get_event_loop().create_future() + request_id = await self._get_new_request_id() + self._pending_requests[request_id] = future + serialized_message = self._serialization_registry.serialize( + message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE + ) + telemetry_metadata = get_telemetry_grpc_metadata() + runtime_message = agent_worker_pb2.Message( + request=agent_worker_pb2.RpcRequest( + request_id=request_id, + target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key), + source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None, + metadata=telemetry_metadata, + payload=agent_worker_pb2.Payload( + data_type=data_type, + data=serialized_message, + data_content_type=JSON_DATA_CONTENT_TYPE, + ), + ) + ) + + # TODO: Find a way to handle timeouts/errors + task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + return await future + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + sender: AgentId | None = None, + cancellation_token: CancellationToken | None = None, + message_id: str | None = None, + ) -> None: + if not self._running: + raise ValueError("Runtime must be running when publishing message.") + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + if message_id is None: + message_id = str(uuid.uuid4()) + + message_type = self._serialization_registry.type_name(message) + with self._trace_helper.trace_block( + "create", topic_id, parent=None, extraAttributes={"message_type": message_type} + ): + serialized_message = self._serialization_registry.serialize( + message, type_name=message_type, data_content_type=self._payload_serialization_format + ) + + sender_id = sender or AgentId("unknown", "unknown") + attributes = { + _constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=self._payload_serialization_format + ), + _constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type), + _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender_id.type + ), + _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender_id.key + ), + _constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH + ), + } + + # If sending JSON we fill text_data with the serialized message + # If sending Protobuf we fill proto_data with the serialized message + # TODO: add an encoding field for serializer + + if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE: + runtime_message = agent_worker_pb2.Message( + cloudEvent=cloudevent_pb2.CloudEvent( + id=message_id, + spec_version="1.0", + type=topic_id.type, + source=topic_id.source, + attributes=attributes, + # TODO: use text, or proto fields appropriately + binary_data=serialized_message, + ) + ) + else: + # We need to unpack the serialized proto back into an Any + # TODO: find a way to prevent the roundtrip serialization + any_proto = any_pb2.Any() + any_proto.ParseFromString(serialized_message) + runtime_message = agent_worker_pb2.Message( + cloudEvent=cloudevent_pb2.CloudEvent( + id=message_id, + spec_version="1.0", + type=topic_id.type, + source=topic_id.source, + attributes=attributes, + proto_data=any_proto, + ) + ) + + telemetry_metadata = get_telemetry_grpc_metadata() + task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + + async def save_state(self) -> Mapping[str, Any]: + raise NotImplementedError("Saving state is not yet implemented.") + + async def load_state(self, state: Mapping[str, Any]) -> None: + raise NotImplementedError("Loading state is not yet implemented.") + + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: + raise NotImplementedError("Agent metadata is not yet implemented.") + + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + raise NotImplementedError("Agent save_state is not yet implemented.") + + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + raise NotImplementedError("Agent load_state is not yet implemented.") + + async def _get_new_request_id(self) -> str: + async with self._pending_requests_lock: + self._next_request_id += 1 + return str(self._next_request_id) + + async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: + assert self._host_connection is not None + recipient = AgentId(request.target.type, request.target.key) + sender: AgentId | None = None + if request.HasField("source"): + sender = AgentId(request.source.type, request.source.key) + logging.info(f"Processing request from {sender} to {recipient}") + else: + logging.info(f"Processing request from unknown source to {recipient}") + + # Deserialize the message. + message = self._serialization_registry.deserialize( + request.payload.data, + type_name=request.payload.data_type, + data_content_type=request.payload.data_content_type, + ) + + # Get the receiving agent and prepare the message context. + rec_agent = await self._get_agent(recipient) + message_context = MessageContext( + sender=sender, + topic_id=None, + is_rpc=True, + cancellation_token=CancellationToken(), + message_id=request.request_id, + ) + + # Call the receiving agent. + try: + with MessageHandlerContext.populate_context(rec_agent.id): + with self._trace_helper.trace_block( + "process", + rec_agent.id, + parent=request.metadata, + attributes={"request_id": request.request_id}, + extraAttributes={"message_type": request.payload.data_type}, + ): + result = await rec_agent.on_message(message, ctx=message_context) + except BaseException as e: + response_message = agent_worker_pb2.Message( + response=agent_worker_pb2.RpcResponse( + request_id=request.request_id, + error=str(e), + metadata=get_telemetry_grpc_metadata(), + ), + ) + # Send the error response. + await self._host_connection.send(response_message) + return + + # Serialize the result. + result_type = self._serialization_registry.type_name(result) + serialized_result = self._serialization_registry.serialize( + result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE + ) + + # Create the response message. + response_message = agent_worker_pb2.Message( + response=agent_worker_pb2.RpcResponse( + request_id=request.request_id, + payload=agent_worker_pb2.Payload( + data_type=result_type, + data=serialized_result, + data_content_type=JSON_DATA_CONTENT_TYPE, + ), + metadata=get_telemetry_grpc_metadata(), + ) + ) + + # Send the response. + await self._host_connection.send(response_message) + + async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None: + with self._trace_helper.trace_block( + "ack", + None, + parent=response.metadata, + attributes={"request_id": response.request_id}, + extraAttributes={"message_type": response.payload.data_type}, + ): + # Deserialize the result. + result = self._serialization_registry.deserialize( + response.payload.data, + type_name=response.payload.data_type, + data_content_type=response.payload.data_content_type, + ) + # Get the future and set the result. + future = self._pending_requests.pop(response.request_id) + if len(response.error) > 0: + future.set_exception(Exception(response.error)) + else: + future.set_result(result) + + async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: + event_attributes = event.attributes + sender: AgentId | None = None + if ( + _constants.AGENT_SENDER_TYPE_ATTR in event_attributes + and _constants.AGENT_SENDER_KEY_ATTR in event_attributes + ): + sender = AgentId( + event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string, + event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string, + ) + topic_id = TopicId(event.type, event.source) + # Get the recipients for the topic. + recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) + + message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string + message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string + + if message_content_type == JSON_DATA_CONTENT_TYPE: + message = self._serialization_registry.deserialize( + event.binary_data, type_name=message_type, data_content_type=message_content_type + ) + elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE: + # TODO: find a way to prevent the roundtrip serialization + proto_binary_data = event.proto_data.SerializeToString() + message = self._serialization_registry.deserialize( + proto_binary_data, type_name=message_type, data_content_type=message_content_type + ) + else: + raise ValueError(f"Unsupported message content type: {message_content_type}") + + # TODO: dont read these values in the runtime + topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else "" + is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST + is_marked_rpc_type = ( + _constants.MESSAGE_KIND_ATTR in event_attributes + and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST + ) + if is_rpc and not is_marked_rpc_type: + warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2) + + # Send the message to each recipient. + responses: List[Awaitable[Any]] = [] + for agent_id in recipients: + if agent_id == sender: + continue + message_context = MessageContext( + sender=sender, + topic_id=topic_id, + is_rpc=is_rpc, + cancellation_token=CancellationToken(), + message_id=event.id, + ) + agent = await self._get_agent(agent_id) + with MessageHandlerContext.populate_context(agent.id): + + def stringify_attributes( + attributes: Mapping[str, cloudevent_pb2.CloudEvent.CloudEventAttributeValue], + ) -> Mapping[str, str]: + result: Dict[str, str] = {} + for key, value in attributes.items(): + item = None + match value.WhichOneof("attr"): + case "ce_boolean": + item = str(value.ce_boolean) + case "ce_integer": + item = str(value.ce_integer) + case "ce_string": + item = value.ce_string + case "ce_bytes": + item = str(value.ce_bytes) + case "ce_uri": + item = value.ce_uri + case "ce_uri_ref": + item = value.ce_uri_ref + case "ce_timestamp": + item = str(value.ce_timestamp) + case _: + raise ValueError("Unknown attribute kind") + result[key] = item + + return result + + async def send_message(agent: Agent, message_context: MessageContext) -> Any: + with self._trace_helper.trace_block( + "process", + agent.id, + parent=stringify_attributes(event.attributes), + extraAttributes={"message_type": message_type}, + ): + await agent.on_message(message, ctx=message_context) + + future = send_message(agent, message_context) + responses.append(future) + # Wait for all responses. + try: + await asyncio.gather(*responses) + except BaseException as e: + logger.error("Error handling event", exc_info=e) + + async def _register_agent_type(self, agent_type: str) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type) + _response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent( + message, metadata=self._host_connection.metadata + ) + + async def register_factory( + self, + type: str | AgentType, + agent_factory: Callable[[], T | Awaitable[T]], + *, + expected_class: type[T] | None = None, + ) -> AgentType: + if isinstance(type, str): + type = AgentType(type) + + if type.type in self._agent_factories: + raise ValueError(f"Agent with type {type} already exists.") + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + + async def factory_wrapper() -> T: + maybe_agent_instance = agent_factory() + if inspect.isawaitable(maybe_agent_instance): + agent_instance = await maybe_agent_instance + else: + agent_instance = maybe_agent_instance + + if expected_class is not None and type_func_alias(agent_instance) != expected_class: + raise ValueError("Factory registered using the wrong type.") + + return agent_instance + + self._agent_factories[type.type] = factory_wrapper + # Send the registration request message to the host. + await self._register_agent_type(type.type) + + return type + + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + def agent_factory() -> Agent: + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + await self._register_agent_type(agent_id.type) + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) + self._instantiated_agents[agent_id] = agent_instance + return agent_id + + async def _invoke_agent_factory( + self, + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], + agent_id: AgentId, + ) -> T: + with AgentInstantiationContext.populate_context((self, agent_id)): + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + warnings.warn( + "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.", + stacklevel=2, + ) + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") + + if inspect.isawaitable(agent): + agent = cast(T, await agent) + + return agent + + async def _get_agent(self, agent_id: AgentId) -> Agent: + if agent_id in self._instantiated_agents: + return self._instantiated_agents[agent_id] + + if agent_id.type not in self._agent_factories: + raise ValueError(f"Agent with name {agent_id.type} not found.") + + agent_factory = self._agent_factories[agent_id.type] + agent = await self._invoke_agent_factory(agent_factory, agent_id) + self._instantiated_agents[agent_id] = agent + return agent + + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 + async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] + if id.type not in self._agent_factories: + raise LookupError(f"Agent with name {id.type} not found.") + + # TODO: check if remote + agent_instance = await self._get_agent(id) + + if not isinstance(agent_instance, type): + raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}") + + return agent_instance + + async def add_subscription(self, subscription: Subscription) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + + message = agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription)) + _response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription( + message, metadata=self._host_connection.metadata + ) + + # Add to local subscription manager. + await self._subscription_manager.add_subscription(subscription) + + async def remove_subscription(self, id: str) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + + message = agent_worker_pb2.RemoveSubscriptionRequest(id=id) + _response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription( + message, metadata=self._host_connection.metadata + ) + + await self._subscription_manager.remove_subscription(id) + + async def get( + self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True + ) -> AgentId: + return await get_impl( + id_or_type=id_or_type, + key=key, + lazy=lazy, + instance_getter=self._get_agent, + ) + + def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: + self._serialization_registry.add_serializer(serializer) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host.py new file mode 100644 index 0000000..2cca982 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host.py @@ -0,0 +1,73 @@ +import asyncio +import logging +import signal +from typing import Optional, Sequence + +from ._constants import GRPC_IMPORT_ERROR_STR +from ._type_helpers import ChannelArgumentType +from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer + +try: + import grpc +except ImportError as e: + raise ImportError(GRPC_IMPORT_ERROR_STR) from e +from .protos import agent_worker_pb2_grpc + +logger = logging.getLogger("agentdhal_core") + + +class GrpcWorkerAgentRuntimeHost: + def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None: + self._server = grpc.aio.server(options=extra_grpc_config) + self._servicer = GrpcWorkerAgentRuntimeHostServicer() + agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server) + self._server.add_insecure_port(address) + self._address = address + self._serve_task: asyncio.Task[None] | None = None + + async def _serve(self) -> None: + await self._server.start() + logger.info(f"Server started at {self._address}.") + await self._server.wait_for_termination() + + def start(self) -> None: + """Start the server in a background task.""" + if self._serve_task is not None: + raise RuntimeError("Host runtime is already started.") + self._serve_task = asyncio.create_task(self._serve()) + + async def stop(self, grace: int = 5) -> None: + """Stop the server.""" + if self._serve_task is None: + raise RuntimeError("Host runtime is not started.") + await self._server.stop(grace=grace) + self._serve_task.cancel() + try: + await self._serve_task + except asyncio.CancelledError: + pass + logger.info("Server stopped.") + self._serve_task = None + + async def stop_when_signal( + self, grace: int = 5, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT) + ) -> None: + """Stop the server when a signal is received.""" + if self._serve_task is None: + raise RuntimeError("Host runtime is not started.") + # Set up signal handling for graceful shutdown. + loop = asyncio.get_running_loop() + shutdown_event = asyncio.Event() + + def signal_handler() -> None: + logger.info("Received exit signal, shutting down gracefully...") + shutdown_event.set() + + for sig in signals: + loop.add_signal_handler(sig, signal_handler) + + # Wait for the signal to trigger the shutdown event. + await shutdown_event.wait() + + # Shutdown the server. + await self.stop(grace=grace) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host_servicer.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host_servicer.py new file mode 100644 index 0000000..a5004db --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime_host_servicer.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from asyncio import Future, Task +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar + +from agentdhal_core import TopicId +from agentdhal_core._agent_id import AgentId +from agentdhal_core._runtime_impl_helpers import SubscriptionManager + +from ._constants import GRPC_IMPORT_ERROR_STR +from ._utils import subscription_from_proto, subscription_to_proto + +try: + import grpc +except ImportError as e: + raise ImportError(GRPC_IMPORT_ERROR_STR) from e + +from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2 + +logger = logging.getLogger("agentdhal_core") +event_logger = logging.getLogger("agentdhal_core.events") + +ClientConnectionId = str + + +def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]: + if metadata is None: + return {} + return {key: value for key, value in metadata} + + +async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str: # type: ignore + # The type hint on context.invocation_metadata() is incorrect. + metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore + if (client_id := metadata.get("client-id")) is None: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.") + + return client_id # type: ignore + + +SendT = TypeVar("SendT") +ReceiveT = TypeVar("ReceiveT") + + +class ChannelConnection(ABC, Generic[SendT, ReceiveT]): + def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None: + self._request_iterator = request_iterator + self._client_id = client_id + self._send_queue: asyncio.Queue[SendT] = asyncio.Queue() + self._receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator)) + + async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None: + # Receive messages from the client and process them. + async for message in request_iterator: + logger.info(f"Received message from client {client_id}: {message}") + await self._handle_message(message) + + def __aiter__(self) -> AsyncIterator[SendT]: + return self + + async def __anext__(self) -> SendT: + try: + return await self._send_queue.get() + except StopAsyncIteration: + await self._receiving_task + raise + except Exception as e: + logger.error(f"Failed to get message from send queue: {e}", exc_info=True) + await self._receiving_task + raise + + @abstractmethod + async def _handle_message(self, message: ReceiveT) -> None: + pass + + async def send(self, message: SendT) -> None: + await self._send_queue.put(message) + + +class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]): + def __init__( + self, + request_iterator: AsyncIterator[ReceiveT], + client_id: str, + handle_callback: Callable[[ReceiveT], Awaitable[None]], + ) -> None: + self._handle_callback = handle_callback + super().__init__(request_iterator, client_id) + + async def _handle_message(self, message: ReceiveT) -> None: + await self._handle_callback(message) + + +class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): + """A gRPC servicer that hosts message delivery service for agents.""" + + def __init__(self) -> None: + self._data_connections: Dict[ + ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message] + ] = {} + self._control_connections: Dict[ + ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage] + ] = {} + self._agent_type_to_client_id_lock = asyncio.Lock() + self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {} + self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {} + self._background_tasks: Set[Task[Any]] = set() + self._subscription_manager = SubscriptionManager() + self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {} + + async def OpenChannel( # type: ignore + self, + request_iterator: AsyncIterator[agent_worker_pb2.Message], + context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message], + ) -> AsyncIterator[agent_worker_pb2.Message]: + client_id = await get_client_id_or_abort(context) + + async def handle_callback(message: agent_worker_pb2.Message) -> None: + await self._receive_message(client_id, message) + + connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]( + request_iterator, client_id, handle_callback=handle_callback + ) + self._data_connections[client_id] = connection + logger.info(f"Client {client_id} connected.") + + try: + async for message in connection: + yield message + finally: + # Clean up the client connection. + del self._data_connections[client_id] + # Cancel pending requests sent to this client. + for future in self._pending_responses.pop(client_id, {}).values(): + future.cancel() + # Remove the client id from the agent type to client id mapping. + await self._on_client_disconnect(client_id) + + async def OpenControlChannel( # type: ignore + self, + request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage], + context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage], + ) -> AsyncIterator[agent_worker_pb2.ControlMessage]: + client_id = await get_client_id_or_abort(context) + + async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None: + await self._receive_control_message(client_id, message) + + connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]( + request_iterator, client_id, handle_callback=handle_callback + ) + self._control_connections[client_id] = connection + logger.info(f"Client {client_id} connected.") + + try: + async for message in connection: + yield message + finally: + # Clean up the client connection. + del self._control_connections[client_id] + + async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: + async with self._agent_type_to_client_id_lock: + agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id] + for agent_type in agent_types: + logger.info(f"Removing agent type {agent_type} from agent type to client id mapping") + del self._agent_type_to_client_id[agent_type] + for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()): + logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}") + try: + await self._subscription_manager.remove_subscription(sub_id) + # Catch and ignore if the subscription does not exist. + except ValueError: + continue + logger.info(f"Client {client_id} disconnected successfully") + + def _raise_on_exception(self, task: Task[Any]) -> None: + exception = task.exception() + if exception is not None: + raise exception + + async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None: + logger.info(f"Received message from client {client_id}: {message}") + oneofcase = message.WhichOneof("message") + match oneofcase: + case "request": + request: agent_worker_pb2.RpcRequest = message.request + task = asyncio.create_task(self._process_request(request, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "response": + response: agent_worker_pb2.RpcResponse = message.response + task = asyncio.create_task(self._process_response(response, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "cloudEvent": + task = asyncio.create_task(self._process_event(message.cloudEvent)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case None: + logger.warning("Received empty message") + + async def _receive_control_message( + self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage + ) -> None: + logger.info(f"Received message from client {client_id}: {message}") + destination = message.destination + if destination.startswith("agentid="): + agent_id = AgentId.from_str(destination[len("agentid=") :]) + target_client_id = self._agent_type_to_client_id.get(agent_id.type) + if target_client_id is None: + logger.error(f"Agent client id not found for agent type {agent_id.type}.") + return + elif destination.startswith("clientid="): + target_client_id = destination[len("clientid=") :] + else: + logger.error(f"Invalid destination {destination}") + return + + target_send_queue = self._control_connections.get(target_client_id) + if target_send_queue is None: + logger.error(f"Client {target_client_id} not found, failed to deliver message.") + return + await target_send_queue.send(message) + + async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None: + # Deliver the message to a client given the target agent type. + async with self._agent_type_to_client_id_lock: + target_client_id = self._agent_type_to_client_id.get(request.target.type) + if target_client_id is None: + logger.error(f"Agent {request.target.type} not found, failed to deliver message.") + return + target_send_queue = self._data_connections.get(target_client_id) + if target_send_queue is None: + logger.error(f"Client {target_client_id} not found, failed to deliver message.") + return + await target_send_queue.send(agent_worker_pb2.Message(request=request)) + + # Create a future to wait for the response from the target. + future = asyncio.get_event_loop().create_future() + self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future + + # Create a task to wait for the response and send it back to the client. + send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) + self._background_tasks.add(send_response_task) + send_response_task.add_done_callback(self._raise_on_exception) + send_response_task.add_done_callback(self._background_tasks.discard) + + async def _wait_and_send_response( + self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId + ) -> None: + response = await future + message = agent_worker_pb2.Message(response=response) + send_queue = self._data_connections.get(client_id) + if send_queue is None: + logger.error(f"Client {client_id} not found, failed to send response message.") + return + await send_queue.send(message) + + async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None: + # Setting the result of the future will send the response back to the original sender. + future = self._pending_responses[client_id].pop(response.request_id) + future.set_result(response) + + async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: + topic_id = TopicId(type=event.type, source=event.source) + recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) + # Get the client ids of the recipients. + async with self._agent_type_to_client_id_lock: + client_ids: Set[ClientConnectionId] = set() + for recipient in recipients: + client_id = self._agent_type_to_client_id.get(recipient.type) + if client_id is not None: + client_ids.add(client_id) + else: + logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.") + # Deliver the event to clients. + for client_id in client_ids: + await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event)) + + async def RegisterAgent( # type: ignore + self, + request: agent_worker_pb2.RegisterAgentTypeRequest, + context: grpc.aio.ServicerContext[ + agent_worker_pb2.RegisterAgentTypeRequest, agent_worker_pb2.RegisterAgentTypeResponse + ], + ) -> agent_worker_pb2.RegisterAgentTypeResponse: + client_id = await get_client_id_or_abort(context) + + async with self._agent_type_to_client_id_lock: + if request.type in self._agent_type_to_client_id: + existing_client_id = self._agent_type_to_client_id[request.type] + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f"Agent type {request.type} already registered with client {existing_client_id}.", + ) + else: + self._agent_type_to_client_id[request.type] = client_id + + return agent_worker_pb2.RegisterAgentTypeResponse() + + async def AddSubscription( # type: ignore + self, + request: agent_worker_pb2.AddSubscriptionRequest, + context: grpc.aio.ServicerContext[ + agent_worker_pb2.AddSubscriptionRequest, agent_worker_pb2.AddSubscriptionResponse + ], + ) -> agent_worker_pb2.AddSubscriptionResponse: + client_id = await get_client_id_or_abort(context) + + subscription = subscription_from_proto(request.subscription) + try: + await self._subscription_manager.add_subscription(subscription) + subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set()) + subscription_ids.add(subscription.id) + except ValueError as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + return agent_worker_pb2.AddSubscriptionResponse() + + async def RemoveSubscription( # type: ignore + self, + request: agent_worker_pb2.RemoveSubscriptionRequest, + context: grpc.aio.ServicerContext[ + agent_worker_pb2.RemoveSubscriptionRequest, agent_worker_pb2.RemoveSubscriptionResponse + ], + ) -> agent_worker_pb2.RemoveSubscriptionResponse: + _client_id = await get_client_id_or_abort(context) + await self._subscription_manager.remove_subscription(request.id) + return agent_worker_pb2.RemoveSubscriptionResponse() + + async def GetSubscriptions( # type: ignore + self, + request: agent_worker_pb2.GetSubscriptionsRequest, + context: grpc.aio.ServicerContext[ + agent_worker_pb2.GetSubscriptionsRequest, agent_worker_pb2.GetSubscriptionsResponse + ], + ) -> agent_worker_pb2.GetSubscriptionsResponse: + _client_id = await get_client_id_or_abort(context) + subscriptions = self._subscription_manager.subscriptions + return agent_worker_pb2.GetSubscriptionsResponse( + subscriptions=[subscription_to_proto(sub) for sub in subscriptions] + ) + + # async def GetState( # type: ignore + # self, + # request: agent_worker_pb2.AgentId, + # context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse], + # ) -> agent_worker_pb2.GetStateResponse: + # _client_id = await get_client_id_or_abort(context) + # raise NotImplementedError("Method not implemented!") + + # async def SaveState( # type: ignore + # self, + # request: agent_worker_pb2.AgentState, + # context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse], + # ) -> agent_worker_pb2.SaveStateResponse: + # _client_id = await get_client_id_or_abort(context) + # raise NotImplementedError("Method not implemented!") diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/__init__.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/__init__.py new file mode 100644 index 0000000..000d43f --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/__init__.py @@ -0,0 +1,4 @@ +""" +The :mod:`agentdhal_extensions.runtimes.grpc.protos` module provides Google Protobuf classes for agent-worker communication +""" + diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.py new file mode 100644 index 0000000..54209d2 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: agent_worker.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'agent_worker.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import cloudevent_pb2 as cloudevent__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message\"4\n\x10SaveStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\"@\n\x11SaveStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"C\n\x10LoadStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\r\n\x05state\x18\x02 \x01(\t\"1\n\x11LoadStateResponse\x12\x12\n\x05\x65rror\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x87\x01\n\x0e\x43ontrolMessage\x12\x0e\n\x06rpc_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65stination\x18\x02 \x01(\t\x12\x17\n\nrespond_to\x18\x03 \x01(\tH\x00\x88\x01\x01\x12(\n\nrpcMessage\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyB\r\n\x0b_respond_to2\xe7\x03\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12H\n\x12OpenControlChannel\x12\x16.agents.ControlMessage\x1a\x16.agents.ControlMessage(\x01\x30\x01\x12T\n\rRegisterAgent\x12 .agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_worker_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\252\002\032Microsoft.AutoGen.Protobuf' + _globals['_RPCREQUEST_METADATAENTRY']._loaded_options = None + _globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001' + _globals['_RPCRESPONSE_METADATAENTRY']._loaded_options = None + _globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001' + _globals['_AGENTID']._serialized_start=75 + _globals['_AGENTID']._serialized_end=111 + _globals['_PAYLOAD']._serialized_start=113 + _globals['_PAYLOAD']._serialized_end=182 + _globals['_RPCREQUEST']._serialized_start=185 + _globals['_RPCREQUEST']._serialized_end=450 + _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=392 + _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=439 + _globals['_RPCRESPONSE']._serialized_start=453 + _globals['_RPCRESPONSE']._serialized_end=637 + _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=392 + _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=439 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=639 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=679 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=681 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=708 + _globals['_TYPESUBSCRIPTION']._serialized_start=710 + _globals['_TYPESUBSCRIPTION']._serialized_end=768 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=770 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=841 + _globals['_SUBSCRIPTION']._serialized_start=844 + _globals['_SUBSCRIPTION']._serialized_end=1006 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1008 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1076 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1078 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1103 + _globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_start=1105 + _globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_end=1144 + _globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_start=1146 + _globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_end=1174 + _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_start=1176 + _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_end=1201 + _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_start=1203 + _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_end=1274 + _globals['_MESSAGE']._serialized_start=1277 + _globals['_MESSAGE']._serialized_end=1430 + _globals['_SAVESTATEREQUEST']._serialized_start=1432 + _globals['_SAVESTATEREQUEST']._serialized_end=1484 + _globals['_SAVESTATERESPONSE']._serialized_start=1486 + _globals['_SAVESTATERESPONSE']._serialized_end=1550 + _globals['_LOADSTATEREQUEST']._serialized_start=1552 + _globals['_LOADSTATEREQUEST']._serialized_end=1619 + _globals['_LOADSTATERESPONSE']._serialized_start=1621 + _globals['_LOADSTATERESPONSE']._serialized_end=1670 + _globals['_CONTROLMESSAGE']._serialized_start=1673 + _globals['_CONTROLMESSAGE']._serialized_end=1808 + _globals['_AGENTRPC']._serialized_start=1811 + _globals['_AGENTRPC']._serialized_end=2298 +# @@protoc_insertion_point(module_scope) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.pyi b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.pyi new file mode 100644 index 0000000..a12c53e --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2.pyi @@ -0,0 +1,457 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import builtins +from . import cloudevent_pb2 +import collections.abc +import google.protobuf.any_pb2 +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class AgentId(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TYPE_FIELD_NUMBER: builtins.int + KEY_FIELD_NUMBER: builtins.int + type: builtins.str + key: builtins.str + def __init__( + self, + *, + type: builtins.str = ..., + key: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "type", b"type"]) -> None: ... + +global___AgentId = AgentId + +@typing.final +class Payload(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATA_TYPE_FIELD_NUMBER: builtins.int + DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + data_type: builtins.str + data_content_type: builtins.str + data: builtins.bytes + def __init__( + self, + *, + data_type: builtins.str = ..., + data_content_type: builtins.str = ..., + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ... + +global___Payload = Payload + +@typing.final +class RpcRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class MetadataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + REQUEST_ID_FIELD_NUMBER: builtins.int + SOURCE_FIELD_NUMBER: builtins.int + TARGET_FIELD_NUMBER: builtins.int + METHOD_FIELD_NUMBER: builtins.int + PAYLOAD_FIELD_NUMBER: builtins.int + METADATA_FIELD_NUMBER: builtins.int + request_id: builtins.str + method: builtins.str + @property + def source(self) -> global___AgentId: ... + @property + def target(self) -> global___AgentId: ... + @property + def payload(self) -> global___Payload: ... + @property + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... + def __init__( + self, + *, + request_id: builtins.str = ..., + source: global___AgentId | None = ..., + target: global___AgentId | None = ..., + method: builtins.str = ..., + payload: global___Payload | None = ..., + metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... + +global___RpcRequest = RpcRequest + +@typing.final +class RpcResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class MetadataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + REQUEST_ID_FIELD_NUMBER: builtins.int + PAYLOAD_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int + METADATA_FIELD_NUMBER: builtins.int + request_id: builtins.str + error: builtins.str + @property + def payload(self) -> global___Payload: ... + @property + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... + def __init__( + self, + *, + request_id: builtins.str = ..., + payload: global___Payload | None = ..., + error: builtins.str = ..., + metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ... + +global___RpcResponse = RpcResponse + +@typing.final +class RegisterAgentTypeRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TYPE_FIELD_NUMBER: builtins.int + type: builtins.str + def __init__( + self, + *, + type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["type", b"type"]) -> None: ... + +global___RegisterAgentTypeRequest = RegisterAgentTypeRequest + +@typing.final +class RegisterAgentTypeResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___RegisterAgentTypeResponse = RegisterAgentTypeResponse + +@typing.final +class TypeSubscription(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TOPIC_TYPE_FIELD_NUMBER: builtins.int + AGENT_TYPE_FIELD_NUMBER: builtins.int + topic_type: builtins.str + agent_type: builtins.str + def __init__( + self, + *, + topic_type: builtins.str = ..., + agent_type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ... + +global___TypeSubscription = TypeSubscription + +@typing.final +class TypePrefixSubscription(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TOPIC_TYPE_PREFIX_FIELD_NUMBER: builtins.int + AGENT_TYPE_FIELD_NUMBER: builtins.int + topic_type_prefix: builtins.str + agent_type: builtins.str + def __init__( + self, + *, + topic_type_prefix: builtins.str = ..., + agent_type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]) -> None: ... + +global___TypePrefixSubscription = TypePrefixSubscription + +@typing.final +class Subscription(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int + TYPEPREFIXSUBSCRIPTION_FIELD_NUMBER: builtins.int + id: builtins.str + @property + def typeSubscription(self) -> global___TypeSubscription: ... + @property + def typePrefixSubscription(self) -> global___TypePrefixSubscription: ... + def __init__( + self, + *, + id: builtins.str = ..., + typeSubscription: global___TypeSubscription | None = ..., + typePrefixSubscription: global___TypePrefixSubscription | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["id", b"id", "subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ... + +global___Subscription = Subscription + +@typing.final +class AddSubscriptionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUBSCRIPTION_FIELD_NUMBER: builtins.int + @property + def subscription(self) -> global___Subscription: ... + def __init__( + self, + *, + subscription: global___Subscription | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ... + +global___AddSubscriptionRequest = AddSubscriptionRequest + +@typing.final +class AddSubscriptionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___AddSubscriptionResponse = AddSubscriptionResponse + +@typing.final +class RemoveSubscriptionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + id: builtins.str + def __init__( + self, + *, + id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["id", b"id"]) -> None: ... + +global___RemoveSubscriptionRequest = RemoveSubscriptionRequest + +@typing.final +class RemoveSubscriptionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___RemoveSubscriptionResponse = RemoveSubscriptionResponse + +@typing.final +class GetSubscriptionsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___GetSubscriptionsRequest = GetSubscriptionsRequest + +@typing.final +class GetSubscriptionsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUBSCRIPTIONS_FIELD_NUMBER: builtins.int + @property + def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Subscription]: ... + def __init__( + self, + *, + subscriptions: collections.abc.Iterable[global___Subscription] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["subscriptions", b"subscriptions"]) -> None: ... + +global___GetSubscriptionsResponse = GetSubscriptionsResponse + +@typing.final +class Message(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + REQUEST_FIELD_NUMBER: builtins.int + RESPONSE_FIELD_NUMBER: builtins.int + CLOUDEVENT_FIELD_NUMBER: builtins.int + @property + def request(self) -> global___RpcRequest: ... + @property + def response(self) -> global___RpcResponse: ... + @property + def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... + def __init__( + self, + *, + request: global___RpcRequest | None = ..., + response: global___RpcResponse | None = ..., + cloudEvent: cloudevent_pb2.CloudEvent | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ... + +global___Message = Message + +@typing.final +class SaveStateRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + AGENTID_FIELD_NUMBER: builtins.int + @property + def agentId(self) -> global___AgentId: ... + def __init__( + self, + *, + agentId: global___AgentId | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId"]) -> None: ... + +global___SaveStateRequest = SaveStateRequest + +@typing.final +class SaveStateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATE_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int + state: builtins.str + error: builtins.str + def __init__( + self, + *, + state: builtins.str = ..., + error: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "state", b"state"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... + +global___SaveStateResponse = SaveStateResponse + +@typing.final +class LoadStateRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + AGENTID_FIELD_NUMBER: builtins.int + STATE_FIELD_NUMBER: builtins.int + state: builtins.str + @property + def agentId(self) -> global___AgentId: ... + def __init__( + self, + *, + agentId: global___AgentId | None = ..., + state: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId", "state", b"state"]) -> None: ... + +global___LoadStateRequest = LoadStateRequest + +@typing.final +class LoadStateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ERROR_FIELD_NUMBER: builtins.int + error: builtins.str + def __init__( + self, + *, + error: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... + +global___LoadStateResponse = LoadStateResponse + +@typing.final +class ControlMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + RPC_ID_FIELD_NUMBER: builtins.int + DESTINATION_FIELD_NUMBER: builtins.int + RESPOND_TO_FIELD_NUMBER: builtins.int + RPCMESSAGE_FIELD_NUMBER: builtins.int + rpc_id: builtins.str + """A response message should have the same id as the request message""" + destination: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + """ + respond_to: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + Empty string means the message is a response + """ + @property + def rpcMessage(self) -> google.protobuf.any_pb2.Any: + """One of: + SaveStateRequest saveStateRequest = 2; + SaveStateResponse saveStateResponse = 3; + LoadStateRequest loadStateRequest = 4; + LoadStateResponse loadStateResponse = 5; + """ + + def __init__( + self, + *, + rpc_id: builtins.str = ..., + destination: builtins.str = ..., + respond_to: builtins.str | None = ..., + rpcMessage: google.protobuf.any_pb2.Any | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "destination", b"destination", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage", "rpc_id", b"rpc_id"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_respond_to", b"_respond_to"]) -> typing.Literal["respond_to"] | None: ... + +global___ControlMessage = ControlMessage diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.py new file mode 100644 index 0000000..4a86f17 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.py @@ -0,0 +1,312 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import agent_worker_pb2 as agent__worker__pb2 + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in agent_worker_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class AgentRpcStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.OpenChannel = channel.stream_stream( + '/agents.AgentRpc/OpenChannel', + request_serializer=agent__worker__pb2.Message.SerializeToString, + response_deserializer=agent__worker__pb2.Message.FromString, + _registered_method=True) + self.OpenControlChannel = channel.stream_stream( + '/agents.AgentRpc/OpenControlChannel', + request_serializer=agent__worker__pb2.ControlMessage.SerializeToString, + response_deserializer=agent__worker__pb2.ControlMessage.FromString, + _registered_method=True) + self.RegisterAgent = channel.unary_unary( + '/agents.AgentRpc/RegisterAgent', + request_serializer=agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString, + response_deserializer=agent__worker__pb2.RegisterAgentTypeResponse.FromString, + _registered_method=True) + self.AddSubscription = channel.unary_unary( + '/agents.AgentRpc/AddSubscription', + request_serializer=agent__worker__pb2.AddSubscriptionRequest.SerializeToString, + response_deserializer=agent__worker__pb2.AddSubscriptionResponse.FromString, + _registered_method=True) + self.RemoveSubscription = channel.unary_unary( + '/agents.AgentRpc/RemoveSubscription', + request_serializer=agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString, + response_deserializer=agent__worker__pb2.RemoveSubscriptionResponse.FromString, + _registered_method=True) + self.GetSubscriptions = channel.unary_unary( + '/agents.AgentRpc/GetSubscriptions', + request_serializer=agent__worker__pb2.GetSubscriptionsRequest.SerializeToString, + response_deserializer=agent__worker__pb2.GetSubscriptionsResponse.FromString, + _registered_method=True) + + +class AgentRpcServicer(object): + """Missing associated documentation comment in .proto file.""" + + def OpenChannel(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def OpenControlChannel(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RegisterAgent(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AddSubscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RemoveSubscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetSubscriptions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AgentRpcServicer_to_server(servicer, server): + rpc_method_handlers = { + 'OpenChannel': grpc.stream_stream_rpc_method_handler( + servicer.OpenChannel, + request_deserializer=agent__worker__pb2.Message.FromString, + response_serializer=agent__worker__pb2.Message.SerializeToString, + ), + 'OpenControlChannel': grpc.stream_stream_rpc_method_handler( + servicer.OpenControlChannel, + request_deserializer=agent__worker__pb2.ControlMessage.FromString, + response_serializer=agent__worker__pb2.ControlMessage.SerializeToString, + ), + 'RegisterAgent': grpc.unary_unary_rpc_method_handler( + servicer.RegisterAgent, + request_deserializer=agent__worker__pb2.RegisterAgentTypeRequest.FromString, + response_serializer=agent__worker__pb2.RegisterAgentTypeResponse.SerializeToString, + ), + 'AddSubscription': grpc.unary_unary_rpc_method_handler( + servicer.AddSubscription, + request_deserializer=agent__worker__pb2.AddSubscriptionRequest.FromString, + response_serializer=agent__worker__pb2.AddSubscriptionResponse.SerializeToString, + ), + 'RemoveSubscription': grpc.unary_unary_rpc_method_handler( + servicer.RemoveSubscription, + request_deserializer=agent__worker__pb2.RemoveSubscriptionRequest.FromString, + response_serializer=agent__worker__pb2.RemoveSubscriptionResponse.SerializeToString, + ), + 'GetSubscriptions': grpc.unary_unary_rpc_method_handler( + servicer.GetSubscriptions, + request_deserializer=agent__worker__pb2.GetSubscriptionsRequest.FromString, + response_serializer=agent__worker__pb2.GetSubscriptionsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'agents.AgentRpc', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('agents.AgentRpc', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AgentRpc(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def OpenChannel(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/agents.AgentRpc/OpenChannel', + agent__worker__pb2.Message.SerializeToString, + agent__worker__pb2.Message.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def OpenControlChannel(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/agents.AgentRpc/OpenControlChannel', + agent__worker__pb2.ControlMessage.SerializeToString, + agent__worker__pb2.ControlMessage.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RegisterAgent(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/agents.AgentRpc/RegisterAgent', + agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString, + agent__worker__pb2.RegisterAgentTypeResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def AddSubscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/agents.AgentRpc/AddSubscription', + agent__worker__pb2.AddSubscriptionRequest.SerializeToString, + agent__worker__pb2.AddSubscriptionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RemoveSubscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/agents.AgentRpc/RemoveSubscription', + agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString, + agent__worker__pb2.RemoveSubscriptionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetSubscriptions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/agents.AgentRpc/GetSubscriptions', + agent__worker__pb2.GetSubscriptionsRequest.SerializeToString, + agent__worker__pb2.GetSubscriptionsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi new file mode 100644 index 0000000..cc43118 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi @@ -0,0 +1,126 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import abc +from . import agent_worker_pb2 +import collections.abc +import grpc +import grpc.aio +import typing + +_T = typing.TypeVar("_T") + +class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ... + +class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg] + ... + +class AgentRpcStub: + def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ... + OpenChannel: grpc.StreamStreamMultiCallable[ + agent_worker_pb2.Message, + agent_worker_pb2.Message, + ] + + OpenControlChannel: grpc.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, + ] + + RegisterAgent: grpc.UnaryUnaryMultiCallable[ + agent_worker_pb2.RegisterAgentTypeRequest, + agent_worker_pb2.RegisterAgentTypeResponse, + ] + + AddSubscription: grpc.UnaryUnaryMultiCallable[ + agent_worker_pb2.AddSubscriptionRequest, + agent_worker_pb2.AddSubscriptionResponse, + ] + + RemoveSubscription: grpc.UnaryUnaryMultiCallable[ + agent_worker_pb2.RemoveSubscriptionRequest, + agent_worker_pb2.RemoveSubscriptionResponse, + ] + + GetSubscriptions: grpc.UnaryUnaryMultiCallable[ + agent_worker_pb2.GetSubscriptionsRequest, + agent_worker_pb2.GetSubscriptionsResponse, + ] + +class AgentRpcAsyncStub: + OpenChannel: grpc.aio.StreamStreamMultiCallable[ + agent_worker_pb2.Message, + agent_worker_pb2.Message, + ] + + OpenControlChannel: grpc.aio.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, + ] + + RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[ + agent_worker_pb2.RegisterAgentTypeRequest, + agent_worker_pb2.RegisterAgentTypeResponse, + ] + + AddSubscription: grpc.aio.UnaryUnaryMultiCallable[ + agent_worker_pb2.AddSubscriptionRequest, + agent_worker_pb2.AddSubscriptionResponse, + ] + + RemoveSubscription: grpc.aio.UnaryUnaryMultiCallable[ + agent_worker_pb2.RemoveSubscriptionRequest, + agent_worker_pb2.RemoveSubscriptionResponse, + ] + + GetSubscriptions: grpc.aio.UnaryUnaryMultiCallable[ + agent_worker_pb2.GetSubscriptionsRequest, + agent_worker_pb2.GetSubscriptionsResponse, + ] + +class AgentRpcServicer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def OpenChannel( + self, + request_iterator: _MaybeAsyncIterator[agent_worker_pb2.Message], + context: _ServicerContext, + ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ... + + @abc.abstractmethod + def OpenControlChannel( + self, + request_iterator: _MaybeAsyncIterator[agent_worker_pb2.ControlMessage], + context: _ServicerContext, + ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.ControlMessage], collections.abc.AsyncIterator[agent_worker_pb2.ControlMessage]]: ... + + @abc.abstractmethod + def RegisterAgent( + self, + request: agent_worker_pb2.RegisterAgentTypeRequest, + context: _ServicerContext, + ) -> typing.Union[agent_worker_pb2.RegisterAgentTypeResponse, collections.abc.Awaitable[agent_worker_pb2.RegisterAgentTypeResponse]]: ... + + @abc.abstractmethod + def AddSubscription( + self, + request: agent_worker_pb2.AddSubscriptionRequest, + context: _ServicerContext, + ) -> typing.Union[agent_worker_pb2.AddSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.AddSubscriptionResponse]]: ... + + @abc.abstractmethod + def RemoveSubscription( + self, + request: agent_worker_pb2.RemoveSubscriptionRequest, + context: _ServicerContext, + ) -> typing.Union[agent_worker_pb2.RemoveSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.RemoveSubscriptionResponse]]: ... + + @abc.abstractmethod + def GetSubscriptions( + self, + request: agent_worker_pb2.GetSubscriptionsRequest, + context: _ServicerContext, + ) -> typing.Union[agent_worker_pb2.GetSubscriptionsResponse, collections.abc.Awaitable[agent_worker_pb2.GetSubscriptionsResponse]]: ... + +def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ... diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.py new file mode 100644 index 0000000..0872d75 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: cloudevent.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'cloudevent.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63loudevent.proto\x12\x11io.cloudevents.v1\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xb0\x04\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x41\n\nattributes\x18\x05 \x03(\x0b\x32-.io.cloudevents.v1.CloudEvent.AttributesEntry\x12\x15\n\x0b\x62inary_data\x18\x06 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x07 \x01(\tH\x00\x12*\n\nproto_data\x18\x08 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1ai\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.io.cloudevents.v1.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cloudevent_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\252\002\033Microsoft.AutoGen.Contracts' + _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._loaded_options = None + _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_options = b'8\001' + _globals['_CLOUDEVENT']._serialized_start=100 + _globals['_CLOUDEVENT']._serialized_end=660 + _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_start=333 + _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_end=438 + _globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_start=441 + _globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_end=652 +# @@protoc_insertion_point(module_scope) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.pyi b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.pyi new file mode 100644 index 0000000..bbdb162 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2.pyi @@ -0,0 +1,125 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +* +CloudEvent Protobuf Format + +- Required context attributes are explicitly represented. +- Optional and Extension context attributes are carried in a map structure. +- Data may be represented as binary, text, or protobuf messages. +""" + +import builtins +import collections.abc +import google.protobuf.any_pb2 +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import google.protobuf.timestamp_pb2 +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class CloudEvent(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class AttributesEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + @property + def value(self) -> global___CloudEvent.CloudEventAttributeValue: ... + def __init__( + self, + *, + key: builtins.str = ..., + value: global___CloudEvent.CloudEventAttributeValue | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + @typing.final + class CloudEventAttributeValue(google.protobuf.message.Message): + """* + The CloudEvent specification defines + seven attribute value types... + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CE_BOOLEAN_FIELD_NUMBER: builtins.int + CE_INTEGER_FIELD_NUMBER: builtins.int + CE_STRING_FIELD_NUMBER: builtins.int + CE_BYTES_FIELD_NUMBER: builtins.int + CE_URI_FIELD_NUMBER: builtins.int + CE_URI_REF_FIELD_NUMBER: builtins.int + CE_TIMESTAMP_FIELD_NUMBER: builtins.int + ce_boolean: builtins.bool + ce_integer: builtins.int + ce_string: builtins.str + ce_bytes: builtins.bytes + ce_uri: builtins.str + ce_uri_ref: builtins.str + @property + def ce_timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + ce_boolean: builtins.bool = ..., + ce_integer: builtins.int = ..., + ce_string: builtins.str = ..., + ce_bytes: builtins.bytes = ..., + ce_uri: builtins.str = ..., + ce_uri_ref: builtins.str = ..., + ce_timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["attr", b"attr"]) -> typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"] | None: ... + + ID_FIELD_NUMBER: builtins.int + SOURCE_FIELD_NUMBER: builtins.int + SPEC_VERSION_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + ATTRIBUTES_FIELD_NUMBER: builtins.int + BINARY_DATA_FIELD_NUMBER: builtins.int + TEXT_DATA_FIELD_NUMBER: builtins.int + PROTO_DATA_FIELD_NUMBER: builtins.int + id: builtins.str + """-- CloudEvent Context Attributes + + Required Attributes + """ + source: builtins.str + """URI-reference""" + spec_version: builtins.str + type: builtins.str + binary_data: builtins.bytes + text_data: builtins.str + @property + def attributes(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]: + """Optional & Extension Attributes""" + + @property + def proto_data(self) -> google.protobuf.any_pb2.Any: ... + def __init__( + self, + *, + id: builtins.str = ..., + source: builtins.str = ..., + spec_version: builtins.str = ..., + type: builtins.str = ..., + attributes: collections.abc.Mapping[builtins.str, global___CloudEvent.CloudEventAttributeValue] | None = ..., + binary_data: builtins.bytes = ..., + text_data: builtins.str = ..., + proto_data: google.protobuf.any_pb2.Any | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "binary_data", b"binary_data", "data", b"data", "id", b"id", "proto_data", b"proto_data", "source", b"source", "spec_version", b"spec_version", "text_data", b"text_data", "type", b"type"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ... + +global___CloudEvent = CloudEvent diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.py b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.py new file mode 100644 index 0000000..f6d836d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.py @@ -0,0 +1,24 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in cloudevent_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) diff --git a/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.pyi b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.pyi new file mode 100644 index 0000000..0f50cd8 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/runtimes/grpc/protos/cloudevent_pb2_grpc.pyi @@ -0,0 +1,23 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +* +CloudEvent Protobuf Format + +- Required context attributes are explicitly represented. +- Optional and Extension context attributes are carried in a map structure. +- Data may be represented as binary, text, or protobuf messages. +""" + +import abc +import collections.abc +import grpc +import grpc.aio +import typing + +_T = typing.TypeVar("_T") + +class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ... + +class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg] + ... diff --git a/agent_dhal/agentdhal_extensions/teams/__init__.py b/agent_dhal/agentdhal_extensions/teams/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_dhal/agentdhal_extensions/teams/magentic_one.py b/agent_dhal/agentdhal_extensions/teams/magentic_one.py new file mode 100644 index 0000000..c5c56c9 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/teams/magentic_one.py @@ -0,0 +1,287 @@ +import warnings +from typing import Awaitable, Callable, List, Optional, Union + +from agentdhal_agentchat.agents import ApprovalFuncType, CodeExecutorAgent, UserProxyAgent +from agentdhal_agentchat.base import ChatAgent +from agentdhal_agentchat.teams import MagenticOneGroupChat +from agentdhal_core import CancellationToken +from agentdhal_core.code_executor import CodeExecutor +from agentdhal_core.models import ChatCompletionClient + +from agentdhal_extensions.agents.file_surfer import FileSurfer +from agentdhal_extensions.agents.magentic_one import MagenticOneCoderAgent +from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer +from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor +from agentdhal_extensions.models.openai._openai_client import BaseOpenAIChatCompletionClient + +# Docker imports for default code executor +try: + import docker + from docker.errors import DockerException + + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + + _docker_available = True +except ImportError: + docker = None # type: ignore + DockerException = Exception # type: ignore + DockerCommandLineCodeExecutor = None # type: ignore + _docker_available = False + +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + + +def _is_docker_available() -> bool: + """Check if Docker is available and running.""" + if not _docker_available: + return False + + try: + if docker is not None: + client = docker.from_env() + client.ping() # type: ignore + return True + except DockerException: + return False + + return False + + +def _create_default_code_executor() -> CodeExecutor: + """Create the default code executor, preferring Docker if available.""" + if _is_docker_available() and DockerCommandLineCodeExecutor is not None: + try: + return DockerCommandLineCodeExecutor() + except Exception: + # Fallback to local if Docker fails to initialize + pass + + # Issue warning and use local executor if Docker is not available + warnings.warn( + "Docker is not available or not running. Using LocalCommandLineCodeExecutor instead of the recommended DockerCommandLineCodeExecutor. " + "For security, it is recommended to install Docker and ensure it's running before using MagenticOne. " + "To install Docker, visit: https://docs.docker.com/get-docker/", + UserWarning, + stacklevel=3, + ) + return LocalCommandLineCodeExecutor() + + +class MagenticOne(MagenticOneGroupChat): + """ + MagenticOne is a specialized group chat class that integrates various agents + such as FileSurfer, WebSurfer, Coder, and Executor to solve complex tasks. + To read more about the science behind Magentic-One, see the full blog post: `Magentic-One: A Generalist Multi-Agent System for Solving Complex Tasks `_ and the references below. + + Installation: + + .. code-block:: bash + + pip install "agentdhal-ext[magentic-one]" + + + Args: + client (ChatCompletionClient): The client used for model interactions. + hil_mode (bool): Optional; If set to True, adds the UserProxyAgent to the list of agents. + input_func (InputFuncType | None): Optional; Function to use for user input in human-in-the-loop mode. + code_executor (CodeExecutor | None): Optional; Code executor to use. If None, will use Docker if available, otherwise local executor. + approval_func (ApprovalFuncType | None): Optional; Function to approve code execution before running. If None, code will execute without approval. + + .. warning:: + Using Magentic-One involves interacting with a digital world designed for humans, which carries inherent risks. To minimize these risks, consider the following precautions: + + 1. **Use Containers**: Run all tasks in docker containers to isolate the agents and prevent direct system attacks. + 2. **Virtual Environment**: Use a virtual environment to run the agents and prevent them from accessing sensitive data. + 3. **Monitor Logs**: Closely monitor logs during and after execution to detect and mitigate risky behavior. + 4. **Human Oversight**: Run the examples with a human in the loop to supervise the agents and prevent unintended consequences. + 5. **Limit Access**: Restrict the agents' access to the internet and other resources to prevent unauthorized actions. + 6. **Safeguard Data**: Ensure that the agents do not have access to sensitive data or resources that could be compromised. Do not share sensitive information with the agents. + + Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences. Moreover, be cautious that Magentic-One may be susceptible to prompt injection attacks from webpages. + + Architecture: + + Magentic-One is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of domains. It represents a significant step towards developing agents that can complete tasks that people encounter in their work and personal lives. + + Magentic-One work is based on a multi-agent architecture where a lead Orchestrator agent is responsible for high-level planning, directing other agents, and tracking task progress. The Orchestrator begins by creating a plan to tackle the task, gathering needed facts and educated guesses in a Task Ledger that is maintained. At each step of its plan, the Orchestrator creates a Progress Ledger where it self-reflects on task progress and checks whether the task is completed. If the task is not yet completed, it assigns one of Magentic-One's other agents a subtask to complete. After the assigned agent completes its subtask, the Orchestrator updates the Progress Ledger and continues in this way until the task is complete. If the Orchestrator finds that progress is not being made for enough steps, it can update the Task Ledger and create a new plan. + + Overall, Magentic-One consists of the following agents: + + - Orchestrator: The lead agent responsible for task decomposition and planning, directing other agents in executing subtasks, tracking overall progress, and taking corrective actions as needed. + - WebSurfer: An LLM-based agent proficient in commanding and managing the state of a Chromium-based web browser. It performs actions on the browser and reports on the new state of the web page. + - FileSurfer: An LLM-based agent that commands a markdown-based file preview application to read local files of most types. It can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure. + - Coder: An LLM-based agent specialized in writing code, analyzing information collected from other agents, or creating new artifacts. + - ComputerTerminal: Provides the team with access to a console shell where the Coder's programs can be executed, and where new programming libraries can be installed. + + Together, Magentic-One's agents provide the Orchestrator with the tools and capabilities needed to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments. + + Examples: + + .. code-block:: python + + # Autonomously complete a coding task: + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.teams.magentic_one import MagenticOne + from agentdhal_agentchat.ui import Console + + + async def example_usage(): + client = OpenAIChatCompletionClient(model="gpt-4o") + m1 = MagenticOne(client=client) # Uses DockerCommandLineCodeExecutor by default + task = "Write a Python script to fetch data from an API." + result = await Console(m1.run_stream(task=task)) + print(result) + + + if __name__ == "__main__": + asyncio.run(example_usage()) + + + .. code-block:: python + + # Enable human-in-the-loop mode with explicit Docker executor and code approval + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.teams.magentic_one import MagenticOne + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse + + + def user_input_func(prompt: str) -> str: + \"\"\"Custom input function for user interaction.\"\"\" + return input(prompt) + + + def approval_func(request: ApprovalRequest) -> ApprovalResponse: + \"\"\"Simple approval function that requests user input.\"\"\" + print(f\"Code to execute:\\n{request.code}\") + user_input = input("Do you approve this code execution? (y/n): ").strip().lower() + if user_input == 'y': + return ApprovalResponse(approved=True, reason=\"User approved the code execution\") + else: + return ApprovalResponse(approved=False, reason=\"User denied the code execution\") + + + async def example_usage_hil(): + client = OpenAIChatCompletionClient(model="gpt-4o") + # Explicitly specify Docker code executor for better security + async with DockerCommandLineCodeExecutor() as code_executor: + m1 = MagenticOne( + client=client, + hil_mode=True, + input_func=user_input_func, + code_executor=code_executor, + approval_func=approval_func + ) + task = "Write a Python script to fetch data from an API." + result = await Console(m1.run_stream(task=task)) + print(result) + + + if __name__ == "__main__": + asyncio.run(example_usage_hil()) + + + .. code-block:: python + + # Enable code execution approval without human-in-the-loop mode + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.teams.magentic_one import MagenticOne + from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor + from agentdhal_agentchat.ui import Console + from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse + + + def approval_func(request: ApprovalRequest) -> ApprovalResponse: + \"\"\"Simple approval function that requests user input.\"\"\" + print(f\"Code to execute:\\n{request.code}\") + user_input = input("Do you approve this code execution? (y/n): ").strip().lower() + if user_input == 'y': + return ApprovalResponse(approved=True, reason=\"User approved the code execution\") + else: + return ApprovalResponse(approved=False, reason=\"User denied the code execution\") + + + async def example_usage_with_approval(): + client = OpenAIChatCompletionClient(model="gpt-4o") + # Use approval_func for code approval only (hil_mode=False) + async with DockerCommandLineCodeExecutor() as code_executor: + m1 = MagenticOne( + client=client, + hil_mode=False, # No human-in-the-loop for general conversation + code_executor=code_executor, + approval_func=approval_func # But still ask for code execution approval + ) + task = "Write a Python script to fetch data from an API." + result = await Console(m1.run_stream(task=task)) + print(result) + + + if __name__ == "__main__": + asyncio.run(example_usage_with_approval()) + + References: + .. code-block:: bibtex + + @article{fourney2024magentic, + title={Magentic-one: A generalist multi-agent system for solving complex tasks}, + author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others}, + journal={arXiv preprint arXiv:2411.04468}, + year={2024}, + url={https://arxiv.org/abs/2411.04468} + } + + + """ + + def __init__( + self, + client: ChatCompletionClient, + hil_mode: bool = False, + input_func: InputFuncType | None = None, + code_executor: CodeExecutor | None = None, + approval_func: ApprovalFuncType | None = None, + ): + self.client = client + self._validate_client_capabilities(client) + + if code_executor is None: + warnings.warn( + "Instantiating MagenticOne without a code_executor is deprecated. Provide a code_executor to clear this warning (e.g., code_executor=DockerCommandLineCodeExecutor() ).", + DeprecationWarning, + stacklevel=2, + ) + code_executor = _create_default_code_executor() + + fs = FileSurfer("FileSurfer", model_client=client) + ws = MultimodalWebSurfer("WebSurfer", model_client=client) + coder = MagenticOneCoderAgent("Coder", model_client=client) + + executor = CodeExecutorAgent("ComputerTerminal", code_executor=code_executor, approval_func=approval_func) + + agents: List[ChatAgent] = [fs, ws, coder, executor] + if hil_mode: + user_proxy = UserProxyAgent("User", input_func=input_func) + agents.append(user_proxy) + super().__init__(agents, model_client=client) + + def _validate_client_capabilities(self, client: ChatCompletionClient) -> None: + capabilities = client.model_info + required_capabilities = ["function_calling", "json_output"] + + if not all(capabilities.get(cap) for cap in required_capabilities): + warnings.warn( + "Client capabilities for MagenticOne must include vision, " "function calling, and json output.", + stacklevel=2, + ) + + if not isinstance(client, BaseOpenAIChatCompletionClient): + warnings.warn( + "MagenticOne performs best with OpenAI GPT-4o model either " "through OpenAI or Azure OpenAI.", + stacklevel=2, + ) diff --git a/agent_dhal/agentdhal_extensions/tools/azure/__init__.py b/agent_dhal/agentdhal_extensions/tools/azure/__init__.py new file mode 100644 index 0000000..1285118 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/azure/__init__.py @@ -0,0 +1,19 @@ +from ._ai_search import ( + AzureAISearchTool, + BaseAzureAISearchTool, + SearchQuery, + SearchResult, + SearchResults, + VectorizableTextQuery, +) +from ._config import AzureAISearchConfig + +__all__ = [ + "AzureAISearchTool", + "BaseAzureAISearchTool", + "SearchQuery", + "SearchResult", + "SearchResults", + "AzureAISearchConfig", + "VectorizableTextQuery", +] diff --git a/agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py b/agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py new file mode 100644 index 0000000..47a6724 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py @@ -0,0 +1,1137 @@ +from __future__ import annotations + +import asyncio +import logging +import time +from abc import ABC, abstractmethod +from contextvars import ContextVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Protocol, + Union, +) + +from agentdhal_core import CancellationToken, Component +from agentdhal_core.tools import BaseTool, ToolSchema +from pydantic import BaseModel, Field + +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from azure.search.documents.aio import SearchClient + +from ._config import ( + DEFAULT_API_VERSION, + AzureAISearchConfig, +) + +SearchDocument = Dict[str, Any] +MetadataDict = Dict[str, Any] +ContentDict = Dict[str, Any] + +if TYPE_CHECKING: + from azure.search.documents.aio import AsyncSearchItemPaged + + SearchResultsIterable = AsyncSearchItemPaged[SearchDocument] +else: + SearchResultsIterable = Any + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from azure.search.documents.models import ( + VectorizableTextQuery, + VectorizedQuery, + VectorQuery, + ) + +try: + from azure.search.documents.models import VectorizableTextQuery, VectorizedQuery, VectorQuery + + has_azure_search = True +except ImportError: + has_azure_search = False + logger.error( + "The 'azure-search-documents' package is required for this tool but was not found. " + "Please install it with: uv add install azure-search-documents" + ) + + +if TYPE_CHECKING: + from typing import Protocol + + class SearchClientProtocol(Protocol): + async def search(self, **kwargs: Any) -> SearchResultsIterable: ... + async def close(self) -> None: ... +else: + SearchClientProtocol = Any + +__all__ = [ + "AzureAISearchTool", + "BaseAzureAISearchTool", + "SearchQuery", + "SearchResults", + "SearchResult", + "VectorizableTextQuery", + "VectorizedQuery", + "VectorQuery", +] +logger = logging.getLogger(__name__) + + +class SearchQuery(BaseModel): + """Search query parameters. + + This simplified interface only requires a search query string. + All other parameters (top, filters, vector fields, etc.) are specified during tool creation + rather than at query time, making it easier for language models to generate structured output. + + Args: + query (str): The search query text. + """ + + query: str = Field(description="Search query text") + + +class SearchResult(BaseModel): + """Search result. + + Args: + score (float): The search score. + content (ContentDict): The document content. + metadata (MetadataDict): Additional metadata about the document. + """ + + score: float = Field(description="The search score") + content: ContentDict = Field(description="The document content") + metadata: MetadataDict = Field(description="Additional metadata about the document") + + +class SearchResults(BaseModel): + """Container for search results. + + Args: + results (List[SearchResult]): List of search results. + """ + + results: List[SearchResult] = Field(description="List of search results") + + +class EmbeddingProvider(Protocol): + """Protocol defining the interface for embedding generation.""" + + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + ... + + +class EmbeddingProviderMixin: + """Mixin class providing embedding generation functionality.""" + + search_config: AzureAISearchConfig + + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + if not hasattr(self, "search_config"): + raise ValueError("Host class must have a search_config attribute") + + search_config = self.search_config + embedding_provider = getattr(search_config, "embedding_provider", None) + embedding_model = getattr(search_config, "embedding_model", None) + + if not embedding_provider or not embedding_model: + raise ValueError( + "Client-side embedding is not configured. `embedding_provider` and `embedding_model` must be set." + ) from None + + if embedding_provider.lower() == "azure_openai": + try: + from openai import AsyncAzureOpenAI + + from azure.identity import DefaultAzureCredential + except ImportError: + raise ImportError( + "Azure OpenAI SDK is required for client-side embedding generation. " + "Please install it with: uv add openai azure-identity" + ) from None + + api_key = getattr(search_config, "openai_api_key", None) + api_version = getattr(search_config, "openai_api_version", "2023-11-01") + endpoint = getattr(search_config, "openai_endpoint", None) + + if not endpoint: + raise ValueError( + "Azure OpenAI endpoint (`openai_endpoint`) must be provided for client-side Azure OpenAI embeddings." + ) from None + + if api_key: + azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint) + else: + + def get_token() -> str: + credential = DefaultAzureCredential() + token = credential.get_token("https://cognitiveservices.azure.com/.default") + if not token or not token.token: + raise ValueError("Failed to acquire token using DefaultAzureCredential for Azure OpenAI.") + return token.token + + azure_client = AsyncAzureOpenAI( + azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint + ) + + try: + response = await azure_client.embeddings.create(model=embedding_model, input=query) + return response.data[0].embedding + except Exception as e: + raise ValueError(f"Failed to generate embeddings with Azure OpenAI: {str(e)}") from e + + elif embedding_provider.lower() == "openai": + try: + from openai import AsyncOpenAI + except ImportError: + raise ImportError( + "OpenAI SDK is required for client-side embedding generation. " + "Please install it with: uv add openai" + ) from None + + api_key = getattr(search_config, "openai_api_key", None) + openai_client = AsyncOpenAI(api_key=api_key) + + try: + response = await openai_client.embeddings.create(model=embedding_model, input=query) + return response.data[0].embedding + except Exception as e: + raise ValueError(f"Failed to generate embeddings with OpenAI: {str(e)}") from e + else: + raise ValueError( + f"Unsupported client-side embedding provider: {embedding_provider}. " + "Currently supported providers are 'azure_openai' and 'openai'." + ) + + +class BaseAzureAISearchTool( + BaseTool[SearchQuery, SearchResults], Component[AzureAISearchConfig], EmbeddingProvider, ABC +): + """Abstract base class for Azure AI Search tools. + + This class defines the common interface and functionality for all Azure AI Search tools. + It handles configuration management, client initialization, and the abstract methods + that subclasses must implement. + + Attributes: + search_config: Configuration parameters for the search service. + + Note: + This is an abstract base class and should not be instantiated directly. + Use concrete implementations or the factory methods in AzureAISearchTool. + """ + + component_config_schema = AzureAISearchConfig + component_provider_override = "agentdhal_extensions.tools.azure.BaseAzureAISearchTool" + + def __init__( + self, + name: str, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], + description: Optional[str] = None, + api_version: str = DEFAULT_API_VERSION, + query_type: Literal["simple", "full", "semantic", "vector"] = "simple", + search_fields: Optional[List[str]] = None, + select_fields: Optional[List[str]] = None, + vector_fields: Optional[List[str]] = None, + top: Optional[int] = None, + filter: Optional[str] = None, + semantic_config_name: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, + ): + """Initialize the Azure AI Search tool. + + Args: + name (str): The name of this tool instance + endpoint (str): The full URL of your Azure AI Search service + index_name (str): Name of the search index to query + credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Azure credential for authentication + description (Optional[str]): Optional description explaining the tool's purpose + api_version (Optional[str]): Azure AI Search API version to use + query_type (Literal["simple", "full", "semantic", "vector"]): Type of search to perform + search_fields (Optional[List[str]]): Fields to search within documents + select_fields (Optional[List[str]]): Fields to return in search results + vector_fields (Optional[List[str]]): Fields to use for vector search + top (Optional[int]): Maximum number of results to return + filter (Optional[str]): OData filter expression to refine search results + semantic_config_name (Optional[str]): Semantic configuration name for enhanced results + enable_caching (bool): Whether to cache search results + cache_ttl_seconds (int): How long to cache results in seconds + embedding_provider (Optional[str]): Name of embedding provider for client-side embeddings + embedding_model (Optional[str]): Model name for client-side embeddings + openai_api_key (Optional[str]): API key for OpenAI/Azure OpenAI embeddings + openai_api_version (Optional[str]): API version for Azure OpenAI embeddings + openai_endpoint (Optional[str]): Endpoint URL for Azure OpenAI embeddings + """ + if not has_azure_search: + raise ImportError( + "Azure Search SDK is required but not installed. " + "Please install it with: pip install azure-search-documents>=11.4.0" + ) + + if description is None: + description = ( + f"Search for information in the {index_name} index using Azure AI Search. " + f"Supports full-text search with optional filters and semantic capabilities." + ) + + super().__init__( + args_type=SearchQuery, + return_type=SearchResults, + name=name, + description=description, + ) + + processed_credential = self._process_credential(credential) + + self.search_config: AzureAISearchConfig = AzureAISearchConfig( + name=name, + description=description, + endpoint=endpoint, + index_name=index_name, + credential=processed_credential, + api_version=api_version, + query_type=query_type, + search_fields=search_fields, + select_fields=select_fields, + vector_fields=vector_fields, + top=top, + filter=filter, + semantic_config_name=semantic_config_name, + enable_caching=enable_caching, + cache_ttl_seconds=cache_ttl_seconds, + embedding_provider=embedding_provider, + embedding_model=embedding_model, + openai_api_key=openai_api_key, + openai_api_version=openai_api_version, + openai_endpoint=openai_endpoint, + ) + + self._endpoint = endpoint + self._index_name = index_name + self._credential = processed_credential + self._api_version = api_version + + self._client: Optional[SearchClient] = None + self._cache: Dict[str, Dict[str, Any]] = {} + + if self.search_config.api_version == "2023-11-01" and self.search_config.vector_fields: + warning_message = ( + f"When explicitly setting api_version='{self.search_config.api_version}' for vector search: " + f"If client-side embedding is NOT configured (e.g., `embedding_model` is not set), " + f"this tool defaults to service-side vectorization (VectorizableTextQuery), which may fail or have limitations with this API version. " + f"If client-side embedding IS configured, the tool will use VectorizedQuery, which is generally compatible. " + f"For robust vector search, consider omitting api_version (recommended to use SDK default) or use a newer API version." + ) + logger.warning(warning_message) + + async def close(self) -> None: + """Explicitly close the Azure SearchClient if needed (for cleanup).""" + if self._client is not None: + try: + await self._client.close() + except Exception: + pass + finally: + self._client = None + + def _process_credential( + self, credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]] + ) -> Union[AzureKeyCredential, AsyncTokenCredential]: + """Process credential to ensure it's the correct type for async SearchClient. + + Converts dictionary credentials with 'api_key' to AzureKeyCredential objects. + + Args: + credential: The credential in either object or dictionary form + + Returns: + A properly formatted credential object + + Raises: + ValueError: If the credential dictionary doesn't contain an 'api_key' + TypeError: If the credential is not of a supported type + """ + if isinstance(credential, dict): + if "api_key" in credential: + return AzureKeyCredential(credential["api_key"]) + raise ValueError("If credential is a dict, it must contain an 'api_key' key") + + if isinstance(credential, (AzureKeyCredential, AsyncTokenCredential)): + return credential + + raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict") + + async def _get_client(self) -> SearchClient: + """Get the search client for the configured index. + + Returns: + SearchClient: Initialized search client + + Raises: + ValueError: If index doesn't exist or authentication fails + """ + if self._client is not None: + return self._client + + try: + self._client = SearchClient( + endpoint=self.search_config.endpoint, + index_name=self.search_config.index_name, + credential=self.search_config.credential, + api_version=self.search_config.api_version, + ) + return self._client + except ResourceNotFoundError as e: + raise ValueError(f"Index '{self.search_config.index_name}' not found in Azure AI Search service.") from e + except HttpResponseError as e: + if e.status_code == 401: + raise ValueError("Authentication failed. Please check your credentials.") from e + elif e.status_code == 403: + raise ValueError("Permission denied to access this index.") from e + else: + raise ValueError(f"Error connecting to Azure AI Search: {str(e)}") from e + except Exception as e: + raise ValueError(f"Unexpected error initializing search client: {str(e)}") from e + + async def run( + self, args: Union[str, Dict[str, Any], SearchQuery], cancellation_token: Optional[CancellationToken] = None + ) -> SearchResults: + """Execute a search against the Azure AI Search index. + + Args: + args: Search query text or SearchQuery object + cancellation_token: Optional token to cancel the operation + + Returns: + SearchResults: Container with search results and metadata + + Raises: + ValueError: If the search query is empty or invalid + ValueError: If there is an authentication error or other search issue + asyncio.CancelledError: If the operation is cancelled + """ + if isinstance(args, str): + if not args.strip(): + raise ValueError("Search query cannot be empty") + search_query = SearchQuery(query=args) + elif isinstance(args, dict) and "query" in args: + search_query = SearchQuery(query=args["query"]) + elif isinstance(args, SearchQuery): + search_query = args + else: + raise ValueError("Invalid search query format. Expected string, dict with 'query', or SearchQuery") + + if cancellation_token is not None and cancellation_token.is_cancelled(): + raise asyncio.CancelledError("Operation cancelled") + + cache_key = "" + if self.search_config.enable_caching: + cache_key_parts = [ + search_query.query, + str(self.search_config.top), + self.search_config.query_type, + ",".join(sorted(self.search_config.search_fields or [])), + ",".join(sorted(self.search_config.select_fields or [])), + ",".join(sorted(self.search_config.vector_fields or [])), + str(self.search_config.filter or ""), + str(self.search_config.semantic_config_name or ""), + ] + cache_key = ":".join(filter(None, cache_key_parts)) + if cache_key in self._cache: + cache_entry = self._cache[cache_key] + cache_age = time.time() - cache_entry["timestamp"] + if cache_age < self.search_config.cache_ttl_seconds: + logger.debug(f"Using cached results for query: {search_query.query}") + return SearchResults( + results=[ + SearchResult(score=r.score, content=r.content, metadata=r.metadata) + for r in cache_entry["results"] + ] + ) + + try: + search_kwargs: Dict[str, Any] = {} + + if self.search_config.query_type != "vector": + search_kwargs["search_text"] = search_query.query + search_kwargs["query_type"] = self.search_config.query_type + + if self.search_config.search_fields: + search_kwargs["search_fields"] = self.search_config.search_fields # type: ignore[assignment] + + if self.search_config.query_type == "semantic" and self.search_config.semantic_config_name: + search_kwargs["semantic_configuration_name"] = self.search_config.semantic_config_name + + if self.search_config.select_fields: + search_kwargs["select"] = self.search_config.select_fields # type: ignore[assignment] + if self.search_config.filter: + search_kwargs["filter"] = str(self.search_config.filter) + if self.search_config.top is not None: + search_kwargs["top"] = self.search_config.top # type: ignore[assignment] + + if self.search_config.vector_fields and len(self.search_config.vector_fields) > 0: + if not search_query.query: + raise ValueError("Query text cannot be empty for vector search operations") + + use_client_side_embeddings = bool( + self.search_config.embedding_model and self.search_config.embedding_provider + ) + + vector_queries: List[Union[VectorizedQuery, VectorizableTextQuery]] = [] + if use_client_side_embeddings: + from azure.search.documents.models import VectorizedQuery + + embedding_vector: List[float] = await self._get_embedding(search_query.query) + for field_spec in self.search_config.vector_fields: + fields = field_spec if isinstance(field_spec, str) else ",".join(field_spec) + vector_queries.append( + VectorizedQuery( + vector=embedding_vector, + k_nearest_neighbors=self.search_config.top or 5, + fields=fields, + kind="vector", + ) + ) + else: + from azure.search.documents.models import VectorizableTextQuery + + for field in self.search_config.vector_fields: + fields = field if isinstance(field, str) else ",".join(field) + vector_queries.append( + VectorizableTextQuery( # type: ignore + text=search_query.query, + k_nearest_neighbors=self.search_config.top or 5, + fields=fields, + kind="vectorizable", + ) + ) + + search_kwargs["vector_queries"] = vector_queries # type: ignore[assignment] + + if cancellation_token is not None: + dummy_task = asyncio.create_task(asyncio.sleep(60)) + cancellation_token.link_future(dummy_task) + + def is_cancelled() -> bool: + return cancellation_token.is_cancelled() + else: + + def is_cancelled() -> bool: + return False + + client = await self._get_client() + search_results: SearchResultsIterable = await client.search(**search_kwargs) # type: ignore[arg-type] + + results: List[SearchResult] = [] + async for doc in search_results: + if is_cancelled(): + raise asyncio.CancelledError("Operation was cancelled") + + try: + metadata: Dict[str, Any] = {} + content: Dict[str, Any] = {} + + for key, value in doc.items(): + if isinstance(key, str) and key.startswith(("@", "_")): + metadata[key] = value + else: + content[str(key)] = value + + score = float(metadata.get("@search.score", 0.0)) + results.append(SearchResult(score=score, content=content, metadata=metadata)) + except Exception as e: + logger.warning(f"Error processing search document: {e}") + continue + + if self.search_config.enable_caching: + self._cache[cache_key] = {"results": results, "timestamp": time.time()} + + return SearchResults(results=results) + + except asyncio.CancelledError: + raise + except Exception as e: + error_msg = str(e) + if isinstance(e, HttpResponseError): + if hasattr(e, "message") and e.message: + error_msg = e.message + + if "not found" in error_msg.lower(): + raise ValueError(f"Index '{self.search_config.index_name}' not found.") from e + elif "unauthorized" in error_msg.lower() or "401" in error_msg: + raise ValueError(f"Authentication failed: {error_msg}") from e + else: + raise ValueError(f"Error from Azure AI Search: {error_msg}") from e + + def _to_config(self) -> AzureAISearchConfig: + """Convert the current instance to a configuration object.""" + return self.search_config + + @property + def schema(self) -> ToolSchema: + """Return the schema for the tool.""" + return { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query text"}}, + "required": ["query"], + "additionalProperties": False, + }, + "strict": True, + } + + def return_value_as_string(self, value: SearchResults) -> str: + """Convert the search results to a string representation.""" + if not value.results: + return "No results found." + + result_strings: List[str] = [] + for i, result in enumerate(value.results, 1): + content_items = [f"{k}: {str(v) if v is not None else 'None'}" for k, v in result.content.items()] + content_str = ", ".join(content_items) + result_strings.append(f"Result {i} (Score: {result.score:.2f}): {content_str}") + + return "\n".join(result_strings) + + @classmethod + def _validate_config( + cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"] + ) -> None: + """Validate configuration for specific search types.""" + credential = config_dict.get("credential") + if isinstance(credential, str): + raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict") + if isinstance(credential, dict) and "api_key" not in credential: + raise ValueError("If credential is a dict, it must contain an 'api_key' key") + + try: + _ = AzureAISearchConfig(**config_dict) + except Exception as e: + raise ValueError(f"Invalid configuration: {str(e)}") from e + + if search_type == "vector": + vector_fields = config_dict.get("vector_fields") + if not vector_fields or len(vector_fields) == 0: + raise ValueError("vector_fields must contain at least one field name for vector search") + + elif search_type == "hybrid": + vector_fields = config_dict.get("vector_fields") + search_fields = config_dict.get("search_fields") + + if not vector_fields or len(vector_fields) == 0: + raise ValueError("vector_fields must contain at least one field name for hybrid search") + + if not search_fields or len(search_fields) == 0: + raise ValueError("search_fields must contain at least one field name for hybrid search") + + @classmethod + @abstractmethod + def _from_config(cls, config: AzureAISearchConfig) -> "BaseAzureAISearchTool": + """Create a tool instance from a configuration object. + + This is an abstract method that must be implemented by subclasses. + """ + if cls is BaseAzureAISearchTool: + raise NotImplementedError( + "BaseAzureAISearchTool is an abstract base class and cannot be instantiated directly. " + "Use a concrete implementation like AzureAISearchTool." + ) + raise NotImplementedError("Subclasses must implement _from_config") + + @abstractmethod + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + raise NotImplementedError("Subclasses must implement _get_embedding") + + +_allow_private_constructor = ContextVar("_allow_private_constructor", default=False) + + +class AzureAISearchTool(EmbeddingProviderMixin, BaseAzureAISearchTool): + """Azure AI Search tool for querying Azure search indexes. + + This tool provides a simplified interface for querying Azure AI Search indexes using + various search methods. It's recommended to use the factory methods to create + instances tailored for specific search types: + + 1. **Full-Text Search**: For traditional keyword-based searches, Lucene queries, or + semantically re-ranked results. + - Use `AzureAISearchTool.create_full_text_search()` + - Supports `query_type`: "simple" (keyword), "full" (Lucene), "semantic". + + 2. **Vector Search**: For pure similarity searches based on vector embeddings. + - Use `AzureAISearchTool.create_vector_search()` + + 3. **Hybrid Search**: For combining vector search with full-text or semantic search + to get the benefits of both. + - Use `AzureAISearchTool.create_hybrid_search()` + - The text component can be "simple", "full", or "semantic" via the `query_type` parameter. + + Each factory method configures the tool with appropriate defaults and validations + for the chosen search strategy. + + .. warning:: + If you set `query_type="semantic"`, you must also provide a valid `semantic_config_name`. + This configuration must be set up in your Azure AI Search index beforehand. + """ + + component_provider_override = "agentdhal_extensions.tools.azure.AzureAISearchTool" + + @classmethod + def _from_config(cls, config: AzureAISearchConfig) -> "AzureAISearchTool": + """Create a tool instance from a configuration object. + + Args: + config: The configuration object with tool settings + + Returns: + AzureAISearchTool: An initialized tool instance + """ + token = _allow_private_constructor.set(True) + try: + instance = cls( + name=config.name, + description=config.description or "", + endpoint=config.endpoint, + index_name=config.index_name, + credential=config.credential, + api_version=config.api_version, + query_type=config.query_type, + search_fields=config.search_fields, + select_fields=config.select_fields, + vector_fields=config.vector_fields, + top=config.top, + filter=config.filter, + semantic_config_name=config.semantic_config_name, + enable_caching=config.enable_caching, + cache_ttl_seconds=config.cache_ttl_seconds, + embedding_provider=config.embedding_provider, + embedding_model=config.embedding_model, + openai_api_key=config.openai_api_key, + openai_api_version=config.openai_api_version, + openai_endpoint=config.openai_endpoint, + ) + return instance + finally: + _allow_private_constructor.reset(token) + + @classmethod + def _create_from_params( + cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"] + ) -> "AzureAISearchTool": + """Private helper to create an instance from parameters after validation. + + Args: + config_dict: Dictionary with configuration parameters + search_type: Type of search for validation + + Returns: + Configured AzureAISearchTool instance + """ + cls._validate_config(config_dict, search_type) + + token = _allow_private_constructor.set(True) + try: + return cls(**config_dict) + finally: + _allow_private_constructor.reset(token) + + @classmethod + def create_full_text_search( + cls, + name: str, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], + description: Optional[str] = None, + api_version: Optional[str] = None, + query_type: Literal["simple", "full", "semantic"] = "simple", + search_fields: Optional[List[str]] = None, + select_fields: Optional[List[str]] = None, + top: Optional[int] = 5, + filter: Optional[str] = None, + semantic_config_name: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, + ) -> "AzureAISearchTool": + """Create a tool for traditional text-based searches. + + This factory method creates an AzureAISearchTool optimized for full-text search, + supporting keyword matching, Lucene syntax, and semantic search capabilities. + + Args: + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + query_type: Type of text search to perform: + + • **simple** : Basic keyword search that matches exact terms and their variations + • **full**: Advanced search using Lucene query syntax for complex queries + • **semantic**: AI-powered search that understands meaning and context, providing enhanced relevance ranking + search_fields: Fields to search within documents + select_fields: Fields to return in search results + top: Maximum number of results to return (default: 5) + filter: OData filter expression to refine search results + semantic_config_name: Semantic configuration name (required for semantic query_type) + enable_caching: Whether to cache search results + cache_ttl_seconds: How long to cache results in seconds + + Returns: + An initialized AzureAISearchTool for full-text search + + Example: + .. code-block:: python + + from azure.core.credentials import AzureKeyCredential + from agentdhal_extensions.tools.azure import AzureAISearchTool + + # Basic keyword search + tool = AzureAISearchTool.create_full_text_search( + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="simple", # Enable keyword search + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, + ) + + # full text (Lucene query) search + full_text_tool = AzureAISearchTool.create_full_text_search( + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="full", # Enable Lucene query syntax + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, + ) + + # Semantic search with re-ranking + # Note: Make sure your index has semantic configuration enabled + semantic_tool = AzureAISearchTool.create_full_text_search( + name="semantic-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + query_type="semantic", # Enable semantic ranking + semantic_config_name="", # Required for semantic search + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, + ) + + # The search tool can be used with an Agent + # assistant = Agent("assistant", tools=[semantic_tool]) + """ + if query_type == "semantic" and not semantic_config_name: + raise ValueError("semantic_config_name is required when query_type is 'semantic'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": query_type, + "search_fields": search_fields, + "select_fields": select_fields, + "top": top, + "filter": filter, + "semantic_config_name": semantic_config_name, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + } + + return cls._create_from_params(config_dict, "full_text") + + @classmethod + def create_vector_search( + cls, + name: str, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], + vector_fields: List[str], + description: Optional[str] = None, + api_version: Optional[str] = None, + select_fields: Optional[List[str]] = None, + top: int = 5, + filter: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, + ) -> "AzureAISearchTool": + """Create a tool for pure vector/similarity search. + + This factory method creates an AzureAISearchTool optimized for vector search, + allowing for semantic similarity-based matching using vector embeddings. + + Args: + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + vector_fields: Fields to use for vector search (required) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + select_fields: Fields to return in search results + top: Maximum number of results to return / k in k-NN (default: 5) + filter: OData filter expression to refine search results + enable_caching: Whether to cache search results + cache_ttl_seconds: How long to cache results in seconds + embedding_provider: Provider for client-side embeddings (e.g., 'azure_openai', 'openai') + embedding_model: Model for client-side embeddings (e.g., 'text-embedding-ada-002') + openai_api_key: API key for OpenAI/Azure OpenAI embeddings + openai_api_version: API version for Azure OpenAI embeddings + openai_endpoint: Endpoint URL for Azure OpenAI embeddings + + Returns: + An initialized AzureAISearchTool for vector search + + Raises: + ValueError: If vector_fields is empty + ValueError: If embedding_provider is 'azure_openai' without openai_endpoint + ValueError: If required parameters are missing or invalid + + Example Usage: + .. code-block:: python + + from azure.core.credentials import AzureKeyCredential + from agentdhal_extensions.tools.azure import AzureAISearchTool + + # Vector search with service-side vectorization + tool = AzureAISearchTool.create_vector_search( + name="vector-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + vector_fields=["content_vector"], # Your vector field name + select_fields=["content", "title", "url"], # Fields to return in results + top=5, + ) + + # Vector search with Azure OpenAI embeddings + azure_openai_tool = AzureAISearchTool.create_vector_search( + name="azure-openai-vector-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + embedding_provider="azure_openai", # Use Azure OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + openai_api_version="2024-02-15-preview", # Azure OpenAI API version + select_fields=["content", "title", "url"], # Fields to return in results + top=5, + ) + + # Vector search with OpenAI embeddings + openai_tool = AzureAISearchTool.create_vector_search( + name="openai-vector-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + embedding_provider="openai", # Use OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_api_key="", # Your OpenAI API key + select_fields=["content", "title", "url"], # Fields to return in results + top=5, + ) + + # Use the tool with an Agent + # assistant = Agent("assistant", tools=[azure_openai_tool]) + """ + if embedding_provider == "azure_openai" and not openai_endpoint: + raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": "vector", + "select_fields": select_fields, + "vector_fields": vector_fields, + "top": top, + "filter": filter, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + "embedding_provider": embedding_provider, + "embedding_model": embedding_model, + "openai_api_key": openai_api_key, + "openai_api_version": openai_api_version, + "openai_endpoint": openai_endpoint, + } + + return cls._create_from_params(config_dict, "vector") + + @classmethod + def create_hybrid_search( + cls, + name: str, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], + vector_fields: List[str], + search_fields: List[str], + description: Optional[str] = None, + api_version: Optional[str] = None, + query_type: Literal["simple", "full", "semantic"] = "simple", + select_fields: Optional[List[str]] = None, + top: int = 5, + filter: Optional[str] = None, + semantic_config_name: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, + ) -> "AzureAISearchTool": + """Create a tool that combines vector and text search capabilities. + + This factory method creates an AzureAISearchTool configured for hybrid search, + which combines the benefits of vector similarity and traditional text search. + + Args: + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + vector_fields: Fields to use for vector search (required) + search_fields: Fields to use for text search (required) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + query_type: Type of text search to perform: + + • **simple**: Basic keyword search that matches exact terms and their variations + • **full**: Advanced search using Lucene query syntax for complex queries + • **semantic**: AI-powered search that understands meaning and context, providing enhanced relevance ranking + select_fields: Fields to return in search results + top: Maximum number of results to return (default: 5) + filter: OData filter expression to refine search results + semantic_config_name: Semantic configuration name (required if query_type="semantic") + enable_caching: Whether to cache search results + cache_ttl_seconds: How long to cache results in seconds + embedding_provider: Provider for client-side embeddings (e.g., 'azure_openai', 'openai') + embedding_model: Model for client-side embeddings (e.g., 'text-embedding-ada-002') + openai_api_key: API key for OpenAI/Azure OpenAI embeddings + openai_api_version: API version for Azure OpenAI embeddings + openai_endpoint: Endpoint URL for Azure OpenAI embeddings + + Returns: + An initialized AzureAISearchTool for hybrid search + + Raises: + ValueError: If vector_fields or search_fields is empty + ValueError: If query_type is "semantic" without semantic_config_name + ValueError: If embedding_provider is 'azure_openai' without openai_endpoint + ValueError: If required parameters are missing or invalid + + Example: + .. code-block:: python + + from azure.core.credentials import AzureKeyCredential + from agentdhal_extensions.tools.azure import AzureAISearchTool + + # Basic hybrid search with service-side vectorization + tool = AzureAISearchTool.create_hybrid_search( + name="hybrid-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + vector_fields=["content_vector"], # Your vector field name + search_fields=["content", "title"], # Your searchable fields + top=5, + ) + + # Hybrid search with semantic ranking and Azure OpenAI embeddings + semantic_tool = AzureAISearchTool.create_hybrid_search( + name="semantic-hybrid-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + search_fields=["content", "title"], + query_type="semantic", # Enable semantic ranking + semantic_config_name="", # Your semantic config name + embedding_provider="azure_openai", # Use Azure OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + openai_api_version="2024-02-15-preview", # Azure OpenAI API version + select_fields=["content", "title", "url"], # Fields to return in results + filter="language eq 'en'", # Optional OData filter + top=5, + ) + + # The search tool can be used with an Agent + # assistant = Agent("assistant", tools=[semantic_tool]) + """ + if query_type == "semantic" and not semantic_config_name: + raise ValueError("semantic_config_name is required when query_type is 'semantic'") + + if embedding_provider == "azure_openai" and not openai_endpoint: + raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": query_type, + "search_fields": search_fields, + "select_fields": select_fields, + "vector_fields": vector_fields, + "top": top, + "filter": filter, + "semantic_config_name": semantic_config_name, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + "embedding_provider": embedding_provider, + "embedding_model": embedding_model, + "openai_api_key": openai_api_key, + "openai_api_version": openai_api_version, + "openai_endpoint": openai_endpoint, + } + + return cls._create_from_params(config_dict, "hybrid") diff --git a/agent_dhal/agentdhal_extensions/tools/azure/_config.py b/agent_dhal/agentdhal_extensions/tools/azure/_config.py new file mode 100644 index 0000000..b771a8d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/azure/_config.py @@ -0,0 +1,186 @@ +"""Configuration for Azure AI Search tool. + +This module provides configuration classes for the Azure AI Search tool, including +settings for authentication, search behavior, retry policies, and caching. +""" + +import logging +from typing import ( + List, + Literal, + Optional, + TypeVar, + Union, +) + +from pydantic import BaseModel, Field, field_validator, model_validator + +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential + +T = TypeVar("T", bound="AzureAISearchConfig") + +logger = logging.getLogger(__name__) + +QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"] +DEFAULT_API_VERSION = "2023-10-01-preview" + + +class AzureAISearchConfig(BaseModel): + """Configuration for Azure AI Search with validation. + + This class defines the configuration parameters for Azure AI Search tools, including + authentication, search behavior, caching, and embedding settings. + + .. note:: + This class requires the ``azure`` extra for the ``autogen-ext`` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[azure]" + + .. note:: + **Prerequisites:** + + 1. An Azure AI Search service must be created in your Azure subscription. + 2. The search index must be properly configured for your use case: + + - For vector search: Index must have vector fields + - For semantic search: Index must have semantic configuration + - For hybrid search: Both vector fields and text fields must be configured + 3. Required packages: + + - Base functionality: ``azure-search-documents>=11.4.0`` + - For Azure OpenAI embeddings: ``openai azure-identity`` + - For OpenAI embeddings: ``openai`` + + Example Usage: + .. code-block:: python + + from azure.core.credentials import AzureKeyCredential + from agentdhal_extensions.tools.azure import AzureAISearchConfig + + # Basic configuration for full-text search + config = AzureAISearchConfig( + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="simple", + search_fields=["content", "title"], # Update with your searchable fields + top=5, + ) + + # Configuration for vector search with Azure OpenAI embeddings + vector_config = AzureAISearchConfig( + name="vector-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + query_type="vector", + vector_fields=["embedding"], # Update with your vector field name + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + top=5, + ) + + # Configuration for hybrid search with semantic ranking + hybrid_config = AzureAISearchConfig( + name="hybrid-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + query_type="semantic", + semantic_config_name="", # Name of your semantic configuration + search_fields=["content", "title"], # Update with your search fields + vector_fields=["embedding"], # Update with your vector field name + embedding_provider="openai", + embedding_model="text-embedding-ada-002", + openai_api_key="", # Your OpenAI API key + top=5, + ) + """ + + name: str = Field(description="The name of this tool instance") + description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose") + endpoint: str = Field(description="The full URL of your Azure AI Search service") + index_name: str = Field(description="Name of the search index to query") + credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field( + description="Azure credential for authentication (API key or token)" + ) + api_version: str = Field( + default=DEFAULT_API_VERSION, + description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.", + ) + query_type: QueryTypeLiteral = Field( + default="simple", description="Type of search to perform: simple, full, semantic, or vector" + ) + search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents") + select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results") + vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search") + top: Optional[int] = Field( + default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN." + ) + filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results") + semantic_config_name: Optional[str] = Field( + default=None, description="Semantic configuration name for enhanced results" + ) + + enable_caching: bool = Field(default=False, description="Whether to cache search results") + cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds") + + embedding_provider: Optional[str] = Field( + default=None, description="Name of embedding provider for client-side embeddings" + ) + embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings") + openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings") + openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings") + openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings") + + model_config = {"arbitrary_types_allowed": True} + + @field_validator("endpoint") + def validate_endpoint(cls, v: str) -> str: + """Validate that the endpoint is a valid URL.""" + if not v.startswith(("http://", "https://")): + raise ValueError("endpoint must be a valid URL starting with http:// or https://") + return v + + @field_validator("query_type") + def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral: + """Normalize query type to standard values.""" + if not v: + return "simple" + + if isinstance(v, str) and v.lower() == "fulltext": + return "full" + + return v + + @field_validator("top") + def validate_top(cls, v: Optional[int]) -> Optional[int]: + """Ensure top is a positive integer if provided.""" + if v is not None and v <= 0: + raise ValueError("top must be a positive integer") + return v + + @model_validator(mode="after") + def validate_interdependent_fields(self) -> "AzureAISearchConfig": + """Validate interdependent fields after all fields have been parsed.""" + if self.query_type == "semantic" and not self.semantic_config_name: + raise ValueError("semantic_config_name must be provided when query_type is 'semantic'") + + if self.query_type == "vector" and not self.vector_fields: + raise ValueError("vector_fields must be provided for vector search") + + if ( + self.embedding_provider + and self.embedding_provider.lower() == "azure_openai" + and self.embedding_model + and not self.openai_endpoint + ): + raise ValueError("openai_endpoint must be provided for azure_openai embedding provider") + + return self diff --git a/agent_dhal/agentdhal_extensions/tools/code_execution/__init__.py b/agent_dhal/agentdhal_extensions/tools/code_execution/__init__.py new file mode 100644 index 0000000..f58ea00 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/code_execution/__init__.py @@ -0,0 +1,3 @@ +from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool + +__all__ = ["CodeExecutionInput", "CodeExecutionResult", "PythonCodeExecutionTool"] diff --git a/agent_dhal/agentdhal_extensions/tools/code_execution/_code_execution.py b/agent_dhal/agentdhal_extensions/tools/code_execution/_code_execution.py new file mode 100644 index 0000000..412829d --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/code_execution/_code_execution.py @@ -0,0 +1,96 @@ +from agentdhal_core import CancellationToken, Component, ComponentModel +from agentdhal_core.code_executor import CodeBlock, CodeExecutor +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel, Field, model_serializer +from typing_extensions import Self + + +class CodeExecutionInput(BaseModel): + code: str = Field(description="The contents of the Python code block that should be executed") + + +class CodeExecutionResult(BaseModel): + success: bool + output: str + + @model_serializer + def ser_model(self) -> str: + return self.output + + +class PythonCodeExecutionToolConfig(BaseModel): + """Configuration for PythonCodeExecutionTool""" + + executor: ComponentModel + description: str = "Execute Python code blocks." + + +class PythonCodeExecutionTool( + BaseTool[CodeExecutionInput, CodeExecutionResult], Component[PythonCodeExecutionToolConfig] +): + """A tool that executes Python code in a code executor and returns output. + + Example executors: + + * :class:`agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor` + * :class:`agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` + * :class:`agentdhal_extensions.code_executors.azure.ACADynamicSessionsCodeExecutor` + + Example usage: + + .. code-block:: bash + + pip install -U "agentdhal-agentchat" "agentdhal-ext[openai]" "yfinance" "matplotlib" + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor + from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool + + + async def main() -> None: + tool = PythonCodeExecutionTool(LocalCommandLineCodeExecutor(work_dir="coding")) + agent = AssistantAgent( + "assistant", OpenAIChatCompletionClient(model="gpt-4o"), tools=[tool], reflect_on_tool_use=True + ) + await Console( + agent.run_stream( + task="Create a plot of MSFT stock prices in 2024 and save it to a file. Use yfinance and matplotlib." + ) + ) + + + asyncio.run(main()) + + + Args: + executor (CodeExecutor): The code executor that will be used to execute the code blocks. + """ + + component_config_schema = PythonCodeExecutionToolConfig + component_provider_override = "agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool" + + def __init__(self, executor: CodeExecutor): + super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.") + self._executor = executor + + async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult: + code_blocks = [CodeBlock(code=args.code, language="python")] + result = await self._executor.execute_code_blocks( + code_blocks=code_blocks, cancellation_token=cancellation_token + ) + return CodeExecutionResult(success=result.exit_code == 0, output=result.output) + + def _to_config(self) -> PythonCodeExecutionToolConfig: + """Convert current instance to config object""" + return PythonCodeExecutionToolConfig(executor=self._executor.dump_component()) + + @classmethod + def _from_config(cls, config: PythonCodeExecutionToolConfig) -> Self: + """Create instance from config object""" + executor = CodeExecutor.load_component(config.executor) + return cls(executor=executor) diff --git a/agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py b/agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py new file mode 100644 index 0000000..3d73e50 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py @@ -0,0 +1,25 @@ +from ._config import ( + GlobalContextConfig, + GlobalDataConfig, + LocalContextConfig, + LocalDataConfig, + MapReduceConfig, + SearchConfig, +) +from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn +from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn + +__all__ = [ + "GlobalSearchTool", + "LocalSearchTool", + "GlobalDataConfig", + "LocalDataConfig", + "GlobalContextConfig", + "GlobalSearchToolArgs", + "GlobalSearchToolReturn", + "LocalContextConfig", + "LocalSearchToolArgs", + "LocalSearchToolReturn", + "MapReduceConfig", + "SearchConfig", +] diff --git a/agent_dhal/agentdhal_extensions/tools/graphrag/_config.py b/agent_dhal/agentdhal_extensions/tools/graphrag/_config.py new file mode 100644 index 0000000..b7df432 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/graphrag/_config.py @@ -0,0 +1,59 @@ +from pydantic import BaseModel + + +class DataConfig(BaseModel): + input_dir: str + entity_table: str = "entities" + entity_embedding_table: str = "entities" + community_table: str = "communities" + community_level: int = 2 + + +class GlobalDataConfig(DataConfig): + community_report_table: str = "community_reports" + + +class LocalDataConfig(DataConfig): + relationship_table: str = "relationships" + text_unit_table: str = "text_units" + + +class ContextConfig(BaseModel): + max_data_tokens: int = 8000 + + +class GlobalContextConfig(ContextConfig): + use_community_summary: bool = False + shuffle_data: bool = True + include_community_rank: bool = True + min_community_rank: int = 0 + community_rank_name: str = "rank" + include_community_weight: bool = True + community_weight_name: str = "occurrence weight" + normalize_community_weight: bool = True + max_data_tokens: int = 12000 + + +class LocalContextConfig(ContextConfig): + text_unit_prop: float = 0.5 + community_prop: float = 0.25 + include_entity_rank: bool = True + rank_description: str = "number of relationships" + include_relationship_weight: bool = True + relationship_ranking_attribute: str = "rank" + + +class MapReduceConfig(BaseModel): + map_max_tokens: int = 1000 + map_temperature: float = 0.0 + reduce_max_tokens: int = 2000 + reduce_temperature: float = 0.0 + allow_general_knowledge: bool = False + json_mode: bool = False + response_type: str = "multiple paragraphs" + + +class SearchConfig(BaseModel): + max_tokens: int = 1500 + temperature: float = 0.0 + response_type: str = "multiple paragraphs" diff --git a/agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py b/agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py new file mode 100644 index 0000000..2594ec8 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py @@ -0,0 +1,233 @@ +from pathlib import Path + +import pandas as pd +import tiktoken +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs +from graphrag.config.load_config import load_config +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol import ChatModel +from graphrag.query.indexer_adapters import ( + read_indexer_communities, + read_indexer_entities, + read_indexer_reports, +) +from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext +from graphrag.query.structured_search.global_search.search import GlobalSearch + +from ._config import GlobalContextConfig as ContextConfig +from ._config import GlobalDataConfig as DataConfig +from ._config import MapReduceConfig + +_default_context_config = ContextConfig() +_default_mapreduce_config = MapReduceConfig() + + +class GlobalSearchToolArgs(BaseModel): + query: str = Field(..., description="The user query to perform global search on.") + + +class GlobalSearchToolReturn(BaseModel): + answer: str + + +class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]): + """Enables running GraphRAG global search queries as an AutoGen tool. + + This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework. + The search combines graph-based document relationships with semantic embeddings to find relevant information. + + .. note:: + This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package. + + To install: + + .. code-block:: bash + + pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]" + + Before using this tool, you must complete the GraphRAG setup and indexing process: + + 1. Follow the GraphRAG documentation to initialize your project and settings + 2. Configure and tune your prompts for the specific use case + 3. Run the indexing process to generate the required data files + 4. Ensure you have the settings.yaml file from the setup process + + Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/) + for detailed instructions on completing these prerequisite steps. + + Example usage with AssistantAgent: + + .. code-block:: python + + import asyncio + from pathlib import Path + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.tools.graphrag import GlobalSearchTool + from agentdhal_agentchat.agents import AssistantAgent + + + async def main(): + # Initialize the OpenAI client + openai_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + api_key="", + ) + + # Set up global search tool + global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml")) + + # Create assistant agent with the global search tool + assistant_agent = AssistantAgent( + name="search_assistant", + tools=[global_tool], + model_client=openai_client, + system_message=( + "You are a tool selector AI assistant using the GraphRAG framework. " + "Your primary task is to determine the appropriate search tool to call based on the user's query. " + "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function." + ), + ) + + # Run a sample query + query = "What is the overall sentiment of the community reports?" + await Console(assistant_agent.run_stream(task=query)) + + + if __name__ == "__main__": + asyncio.run(main()) + """ + + def __init__( + self, + token_encoder: tiktoken.Encoding, + model: ChatModel, + data_config: DataConfig, + context_config: ContextConfig = _default_context_config, + mapreduce_config: MapReduceConfig = _default_mapreduce_config, + ): + super().__init__( + args_type=GlobalSearchToolArgs, + return_type=GlobalSearchToolReturn, + name="global_search_tool", + description="Perform a global search with given parameters using graphrag.", + ) + # Use the provided model + self._model = model + + # Load parquet files + community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore + entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore + report_df: pd.DataFrame = pd.read_parquet( # type: ignore + f"{data_config.input_dir}/{data_config.community_report_table}.parquet" + ) + + # Fix: Use correct argument order and types for GraphRAG API + communities = read_indexer_communities(community_df, report_df) + reports = read_indexer_reports(report_df, community_df, data_config.community_level) + entities = read_indexer_entities(entity_df, community_df, data_config.community_level) + + context_builder = GlobalCommunityContext( + community_reports=reports, + communities=communities, + entities=entities, + token_encoder=token_encoder, + ) + + context_builder_params = { + "use_community_summary": context_config.use_community_summary, + "shuffle_data": context_config.shuffle_data, + "include_community_rank": context_config.include_community_rank, + "min_community_rank": context_config.min_community_rank, + "community_rank_name": context_config.community_rank_name, + "include_community_weight": context_config.include_community_weight, + "community_weight_name": context_config.community_weight_name, + "normalize_community_weight": context_config.normalize_community_weight, + "max_tokens": context_config.max_data_tokens, + "context_name": "Reports", + } + + map_llm_params = { + "max_tokens": mapreduce_config.map_max_tokens, + "temperature": mapreduce_config.map_temperature, + "response_format": {"type": "json_object"}, + } + + reduce_llm_params = { + "max_tokens": mapreduce_config.reduce_max_tokens, + "temperature": mapreduce_config.reduce_temperature, + } + + self._search_engine = GlobalSearch( + model=self._model, + context_builder=context_builder, + token_encoder=token_encoder, + max_data_tokens=context_config.max_data_tokens, + map_llm_params=map_llm_params, + reduce_llm_params=reduce_llm_params, + allow_general_knowledge=mapreduce_config.allow_general_knowledge, + json_mode=mapreduce_config.json_mode, + context_builder_params=context_builder_params, + concurrent_coroutines=32, + response_type=mapreduce_config.response_type, + ) + + async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn: + search_result = await self._search_engine.search(args.query) + assert isinstance(search_result.response, str), "Expected response to be a string" + return GlobalSearchToolReturn(answer=search_result.response) + + @classmethod + def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool": + """Create a GlobalSearchTool instance from GraphRAG settings file. + + Args: + root_dir: Path to the GraphRAG root directory + config_filepath: Path to the GraphRAG settings file (optional) + + Returns: + An initialized GlobalSearchTool instance + """ + # Load GraphRAG config + if isinstance(root_dir, str): + root_dir = Path(root_dir) + if isinstance(config_filepath, str): + config_filepath = Path(config_filepath) + config = load_config(root_dir=root_dir, config_filepath=config_filepath) + + # Get the language model configuration from the models section + chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID) + + if chat_model_config is None: + raise ValueError("default_chat_model not found in config.models") + + # Initialize token encoder based on the model being used + try: + token_encoder = tiktoken.encoding_for_model(chat_model_config.model) + except KeyError: + # Fallback to cl100k_base if model is not recognized by tiktoken + token_encoder = tiktoken.get_encoding("cl100k_base") + + # Create the LLM using ModelManager + model = ModelManager().get_or_create_chat_model( + name="global_search_model", + model_type=chat_model_config.type, + config=chat_model_config, + ) + + # Create data config from storage paths + data_config = DataConfig( + input_dir=str(config.output.base_dir), + ) + + return cls( + token_encoder=token_encoder, + model=model, + data_config=data_config, + context_config=_default_context_config, + mapreduce_config=_default_mapreduce_config, + ) diff --git a/agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py b/agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py new file mode 100644 index 0000000..54fe3e2 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py @@ -0,0 +1,245 @@ +# mypy: disable-error-code="no-any-unimported,misc" +from pathlib import Path + +import pandas as pd +import tiktoken +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs +from graphrag.config.load_config import load_config +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol import ChatModel, EmbeddingModel +from graphrag.query.indexer_adapters import ( + read_indexer_entities, + read_indexer_relationships, + read_indexer_text_units, +) +from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext +from graphrag.query.structured_search.local_search.search import LocalSearch +from graphrag.vector_stores.lancedb import LanceDBVectorStore + +from ._config import LocalContextConfig, SearchConfig +from ._config import LocalDataConfig as DataConfig + +_default_context_config = LocalContextConfig() +_default_search_config = SearchConfig() + + +class LocalSearchToolArgs(BaseModel): + query: str = Field(..., description="The user query to perform local search on.") + + +class LocalSearchToolReturn(BaseModel): + answer: str = Field(..., description="The answer to the user query.") + + +class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]): + """Enables running GraphRAG local search queries as an AutoGen tool. + + This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework. + The search combines local document context with semantic embeddings to find relevant information. + + .. note:: + This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package. + To install: + + .. code-block:: bash + + pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]" + + Before using this tool, you must complete the GraphRAG setup and indexing process: + + 1. Follow the GraphRAG documentation to initialize your project and settings + 2. Configure and tune your prompts for the specific use case + 3. Run the indexing process to generate the required data files + 4. Ensure you have the settings.yaml file from the setup process + + Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/) + for detailed instructions on completing these prerequisite steps. + + Example usage with AssistantAgent: + + .. code-block:: python + + import asyncio + from pathlib import Path + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.tools.graphrag import LocalSearchTool + from agentdhal_agentchat.agents import AssistantAgent + + + async def main(): + # Initialize the OpenAI client + openai_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + api_key="", + ) + + # Set up local search tool + local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml")) + + # Create assistant agent with the local search tool + assistant_agent = AssistantAgent( + name="search_assistant", + tools=[local_tool], + model_client=openai_client, + system_message=( + "You are a tool selector AI assistant using the GraphRAG framework. " + "Your primary task is to determine the appropriate search tool to call based on the user's query. " + "For specific, detailed information about particular entities or relationships, call the 'local_search' function." + ), + ) + + # Run a sample query + query = "What does the station-master say about Dr. Becher?" + await Console(assistant_agent.run_stream(task=query)) + + + if __name__ == "__main__": + asyncio.run(main()) + + + Args: + token_encoder (tiktoken.Encoding): The tokenizer used for text encoding + model: The chat model to use for search (GraphRAG ChatModel) + embedder: The text embedding model to use (GraphRAG EmbeddingModel) + data_config (DataConfig): Configuration for data source locations and settings + context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config. + search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config. + """ + + def __init__( + self, + token_encoder: tiktoken.Encoding, + model: ChatModel, # ChatModel from GraphRAG + embedder: EmbeddingModel, # EmbeddingModel from GraphRAG + data_config: DataConfig, + context_config: LocalContextConfig = _default_context_config, + search_config: SearchConfig = _default_search_config, + ): + super().__init__( + args_type=LocalSearchToolArgs, + return_type=LocalSearchToolReturn, + name="local_search_tool", + description="Perform a local search with given parameters using graphrag.", + ) + # Use the provided models + self._model = model + self._embedder = embedder + + # Load parquet files + entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore + relationship_df: pd.DataFrame = pd.read_parquet( # type: ignore + f"{data_config.input_dir}/{data_config.relationship_table}.parquet" + ) + text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet") # type: ignore + community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore + + # Read data using indexer adapters + entities = read_indexer_entities(entity_df, community_df, data_config.community_level) + relationships = read_indexer_relationships(relationship_df) + text_units = read_indexer_text_units(text_unit_df) + # Set up vector store for entity embeddings + description_embedding_store = LanceDBVectorStore( + collection_name="default-entity-description", + ) + description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb") + + # Set up context builder + context_builder = LocalSearchMixedContext( + entities=entities, + entity_text_embeddings=description_embedding_store, + text_embedder=self._embedder, + text_units=text_units, + relationships=relationships, + token_encoder=token_encoder, + ) + + context_builder_params = { + "text_unit_prop": context_config.text_unit_prop, + "community_prop": context_config.community_prop, + "include_entity_rank": context_config.include_entity_rank, + "rank_description": context_config.rank_description, + "include_relationship_weight": context_config.include_relationship_weight, + "relationship_ranking_attribute": context_config.relationship_ranking_attribute, + "max_tokens": context_config.max_data_tokens, + } + + llm_params = { + "max_tokens": search_config.max_tokens, + "temperature": search_config.temperature, + } + + self._search_engine = LocalSearch( + model=self._model, + context_builder=context_builder, + token_encoder=token_encoder, + response_type=search_config.response_type, + context_builder_params=context_builder_params, + model_params=llm_params, + ) + + async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn: + search_result = await self._search_engine.search(args.query) # type: ignore[reportUnknownMemberType] + assert isinstance(search_result.response, str), "Expected response to be a string" + return LocalSearchToolReturn(answer=search_result.response) + + @classmethod + def from_settings(cls, root_dir: Path, config_filepath: Path | None = None) -> "LocalSearchTool": + """Create a LocalSearchTool instance from GraphRAG settings file. + + Args: + root_dir: Path to the GraphRAG root directory + config_filepath: Path to the GraphRAG settings file (optional) + + Returns: + An initialized LocalSearchTool instance + """ + # Load GraphRAG config + config = load_config(root_dir=root_dir, config_filepath=config_filepath) + + # Get the language model configurations from the models section + chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID) + embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID) + + if chat_model_config is None: + raise ValueError("default_chat_model not found in config.models") + if embedding_model_config is None: + raise ValueError("default_embedding_model not found in config.models") + + # Initialize token encoder based on the model being used + try: + token_encoder = tiktoken.encoding_for_model(chat_model_config.model) + except KeyError: + # Fallback to cl100k_base if model is not recognized by tiktoken + token_encoder = tiktoken.get_encoding("cl100k_base") + + # Create the models using ModelManager + model = ModelManager().get_or_create_chat_model( + name="local_search_model", + model_type=chat_model_config.type, + config=chat_model_config, + ) + + embedder = ModelManager().get_or_create_embedding_model( + name="local_search_embedder", + model_type=embedding_model_config.type, + config=embedding_model_config, + ) + + # Create data config from storage paths + data_config = DataConfig( + input_dir=str(config.output.base_dir), + ) + + return cls( + token_encoder=token_encoder, + model=model, + embedder=embedder, + data_config=data_config, + context_config=_default_context_config, + search_config=_default_search_config, + ) diff --git a/agent_dhal/agentdhal_extensions/tools/http/__init__.py b/agent_dhal/agentdhal_extensions/tools/http/__init__.py new file mode 100644 index 0000000..6c276b6 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/http/__init__.py @@ -0,0 +1,3 @@ +from ._http_tool import HttpTool + +__all__ = ["HttpTool"] diff --git a/agent_dhal/agentdhal_extensions/tools/http/_http_tool.py b/agent_dhal/agentdhal_extensions/tools/http/_http_tool.py new file mode 100644 index 0000000..af1a20a --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/http/_http_tool.py @@ -0,0 +1,244 @@ +import re +from typing import Any, Literal, Optional, Type + +import httpx +from agentdhal_core import CancellationToken, Component +from agentdhal_core.tools import BaseTool +from json_schema_to_pydantic import create_model +from pydantic import BaseModel, Field +from typing_extensions import Self + +DEFAULT_TIMEOUT_CONFIG = 5.0 + + +class HttpToolConfig(BaseModel): + name: str + """ + The name of the tool. + """ + description: Optional[str] + """ + A description of the tool. + """ + scheme: Literal["http", "https"] = "http" + """ + The scheme to use for the request. + """ + host: str + """ + The URL to send the request to. + """ + port: int + """ + The port to send the request to. + """ + path: str = Field(default="/") + """ + The path to send the request to. defaults to "/" + The path can accept parameters, e.g. "/{param1}/{param2}". + These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request. + """ + method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST" + """ + The HTTP method to use, will default to POST if not provided. + """ + headers: Optional[dict[str, Any]] + """ + A dictionary of headers to send with the request. + """ + json_schema: dict[str, Any] + """ + A JSON Schema object defining the expected parameters for the tool. + Path parameters MUST also be included in the json_schema. They must also MUST be set to string + """ + return_type: Optional[Literal["text", "json"]] = "text" + """ + The type of response to return from the tool. + """ + timeout: float = DEFAULT_TIMEOUT_CONFIG + """ + The timeout for the tool request in seconds. + """ + + +class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]): + """A wrapper for using an HTTP server as a tool. + + Args: + name (str): The name of the tool. + description (str, optional): A description of the tool. + scheme (str): The scheme to use for the request. Must be either "http" or "https". + host (str): The host to send the request to. + port (int): The port to send the request to. + path (str, optional): The path to send the request to. Defaults to "/". + Can include path parameters like "/{param1}/{param2}" which will be templated from input args. + method (str, optional): The HTTP method to use, will default to POST if not provided. + Must be one of "GET", "POST", "PUT", "DELETE", "PATCH". + headers (dict[str, Any], optional): A dictionary of headers to send with the request. + json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool. + Path parameters must also be included in the schema and must be strings. + return_type (Literal["text", "json"], optional): The type of response to return from the tool. + Defaults to "text". + timeout (float, optional): The timeout for HTTP requests in seconds. + Defaults to 5.0. + + .. note:: + This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package. + + To install: + + .. code-block:: bash + + pip install -U "agentdhal-agentchat" "agentdhal-ext[http-tool]" + + Example: + Simple use case:: + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.messages import TextMessage + from agentdhal_core import CancellationToken + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.http import HttpTool + + # Define a JSON schema for a base64 decode tool + base64_schema = { + "type": "object", + "properties": { + "value": {"type": "string", "description": "The base64 value to decode"}, + }, + "required": ["value"], + } + + # Create an HTTP tool for the httpbin API + base64_tool = HttpTool( + name="base64_decode", + description="base64 decode a value", + scheme="https", + host="httpbin.org", + port=443, + path="/base64/{value}", + method="GET", + json_schema=base64_schema, + ) + + + async def main(): + # Create an assistant with the base64 tool + model = OpenAIChatCompletionClient(model="gpt-4") + assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool]) + + # The assistant can now use the base64 tool to decode the string + response = await assistant.on_messages( + [TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")], + CancellationToken(), + ) + print(response.chat_message) + + + asyncio.run(main()) + """ + + component_type = "tool" + component_provider_override = "agentdhal_extensions.tools.http.HttpTool" + component_config_schema = HttpToolConfig + + def __init__( + self, + name: str, + host: str, + port: int, + json_schema: dict[str, Any], + headers: Optional[dict[str, Any]] = None, + description: str = "HTTP tool", + path: str = "/", + scheme: Literal["http", "https"] = "http", + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST", + return_type: Literal["text", "json"] = "text", + timeout: float = DEFAULT_TIMEOUT_CONFIG, + ) -> None: + self.server_params = HttpToolConfig( + name=name, + description=description, + host=host, + port=port, + path=path, + scheme=scheme, + method=method, + headers=headers, + json_schema=json_schema, + return_type=return_type, + timeout=timeout, + ) + + # Use regex to find all path parameters, we will need those later to template the path + path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)} + self._path_params = path_params + + # Create the input model from the modified schema + input_model = create_model(json_schema) + + # Use Any as return type since HTTP responses can vary + base_return_type: Type[Any] = object + + super().__init__(input_model, base_return_type, name, description) + + def _to_config(self) -> HttpToolConfig: + copied_config = self.server_params.model_copy() + return copied_config + + @classmethod + def _from_config(cls, config: HttpToolConfig) -> Self: + copied_config = config.model_copy().model_dump() + return cls(**copied_config) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + """Execute the HTTP tool with the given arguments. + + Args: + args: The validated input arguments + cancellation_token: Token for cancelling the operation + + Returns: + The response body from the HTTP call in JSON format + + Raises: + Exception: If tool execution fails + """ + + model_dump = args.model_dump() + path_params = {k: v for k, v in model_dump.items() if k in self._path_params} + # Remove path params from the model dump + for k in self._path_params: + model_dump.pop(k) + + path = self.server_params.path.format(**path_params) + + url = httpx.URL( + scheme=self.server_params.scheme, + host=self.server_params.host, + port=self.server_params.port, + path=path, + ) + timeout_config = httpx.Timeout(timeout=self.server_params.timeout) + async with httpx.AsyncClient(timeout=timeout_config) as client: + match self.server_params.method: + case "GET": + response = await client.get(url, headers=self.server_params.headers, params=model_dump) + case "PUT": + response = await client.put(url, headers=self.server_params.headers, json=model_dump) + case "DELETE": + response = await client.delete(url, headers=self.server_params.headers, params=model_dump) + case "PATCH": + response = await client.patch(url, headers=self.server_params.headers, json=model_dump) + case _: # Default case POST + response = await client.post(url, headers=self.server_params.headers, json=model_dump) + + match self.server_params.return_type: + case "text": + return response.text + case "json": + return response.json() + case _: + raise ValueError(f"Invalid return type: {self.server_params.return_type}") diff --git a/agent_dhal/agentdhal_extensions/tools/langchain/__init__.py b/agent_dhal/agentdhal_extensions/tools/langchain/__init__.py new file mode 100644 index 0000000..03af958 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/langchain/__init__.py @@ -0,0 +1,3 @@ +from ._langchain_adapter import LangChainToolAdapter + +__all__ = ["LangChainToolAdapter"] diff --git a/agent_dhal/agentdhal_extensions/tools/langchain/_langchain_adapter.py b/agent_dhal/agentdhal_extensions/tools/langchain/_langchain_adapter.py new file mode 100644 index 0000000..15cfd5a --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/langchain/_langchain_adapter.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import asyncio +import inspect +from typing import TYPE_CHECKING, Any, Callable, Dict, Type, cast + +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool +from pydantic import BaseModel, Field, create_model + +if TYPE_CHECKING: + from langchain_core.tools import BaseTool as LangChainTool + + +class LangChainToolAdapter(BaseTool[BaseModel, Any]): + """Allows you to wrap a LangChain tool and make it available to AutoGen. + + .. note:: + + This class requires the :code:`langchain` extra for the :code:`autogen-ext` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[langchain]" + + + Args: + langchain_tool (LangChainTool): A LangChain tool to wrap + + Examples: + + Use the `PythonAstREPLTool` from the `langchain_experimental` package to + create a tool that allows you to interact with a Pandas DataFrame. + + .. code-block:: python + + import asyncio + import pandas as pd + from langchain_experimental.tools.python.tool import PythonAstREPLTool + from agentdhal_extensions.tools.langchain import LangChainToolAdapter + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_agentchat.messages import TextMessage + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core import CancellationToken + + + async def main() -> None: + df = pd.read_csv("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv") # type: ignore + tool = LangChainToolAdapter(PythonAstREPLTool(locals={"df": df})) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent( + "assistant", + tools=[tool], + model_client=model_client, + system_message="Use the `df` variable to access the dataset.", + ) + await Console( + agent.on_messages_stream( + [TextMessage(content="What's the average age of the passengers?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + + This example demonstrates how to use the `SQLDatabaseToolkit` from the `langchain_community` + package to interact with an SQLite database. + It uses the :class:`~agentdhal_agentchat.team.RoundRobinGroupChat` to iterate the single agent over multiple steps. + If you want to one step at a time, you can just call `run_stream` method of the + :class:`~agentdhal_agentchat.agents.AssistantAgent` class directly. + + .. code-block:: python + + import asyncio + import sqlite3 + + import requests + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.langchain import LangChainToolAdapter + from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit + from langchain_community.utilities.sql_database import SQLDatabase + from langchain_openai import ChatOpenAI + from sqlalchemy import Engine, create_engine + from sqlalchemy.pool import StaticPool + + + def get_engine_for_chinook_db() -> Engine: + url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql" + response = requests.get(url) + sql_script = response.text + connection = sqlite3.connect(":memory:", check_same_thread=False) + connection.executescript(sql_script) + return create_engine( + "sqlite://", + creator=lambda: connection, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + + async def main() -> None: + # Create the engine and database wrapper. + engine = get_engine_for_chinook_db() + db = SQLDatabase(engine) + + # Create the toolkit. + llm = ChatOpenAI(temperature=0) + toolkit = SQLDatabaseToolkit(db=db, llm=llm) + + # Create the LangChain tool adapter for every tool in the toolkit. + tools = [LangChainToolAdapter(tool) for tool in toolkit.get_tools()] + + # Create the chat completion client. + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + # Create the assistant agent. + agent = AssistantAgent( + "assistant", + model_client=model_client, + tools=tools, # type: ignore + model_client_stream=True, + system_message="Respond with 'TERMINATE' if the task is completed.", + ) + + # Create termination condition. + termination = TextMentionTermination("TERMINATE") + + # Create a round-robin group chat to iterate the single agent over multiple steps. + chat = RoundRobinGroupChat([agent], termination_condition=termination) + + # Run the chat. + await Console(chat.run_stream(task="Show some tables in the database")) + + + if __name__ == "__main__": + asyncio.run(main()) + + """ + + def __init__(self, langchain_tool: LangChainTool): + self._langchain_tool: LangChainTool = langchain_tool + + # Extract name and description + name = self._langchain_tool.name + description = self._langchain_tool.description or "" + + # Determine the callable method + if hasattr(self._langchain_tool, "func") and callable(self._langchain_tool.func): # type: ignore + assert self._langchain_tool.func is not None # type: ignore + self._callable: Callable[..., Any] = self._langchain_tool.func # type: ignore + elif hasattr(self._langchain_tool, "_run") and callable(self._langchain_tool._run): # type: ignore + self._callable: Callable[..., Any] = self._langchain_tool._run # type: ignore + else: + raise AttributeError( + f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method." + ) + + # Determine args_type + if self._langchain_tool.args_schema: # pyright: ignore + args_type = self._langchain_tool.args_schema # pyright: ignore + else: + # Infer args_type from the callable's signature + sig = inspect.signature(cast(Callable[..., Any], self._callable)) # type: ignore + fields = { + k: (v.annotation, Field(...)) + for k, v in sig.parameters.items() + if k != "self" and v.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + } + args_type = create_model(f"{name}Args", **fields) # type: ignore + # Note: type ignore is used due to a LangChain typing limitation + + # Ensure args_type is a subclass of BaseModel + if not issubclass(args_type, BaseModel): + raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}") + + # Assume return_type as Any if not specified + return_type: Type[Any] = object + + super().__init__(args_type, return_type, name, description) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + # Prepare arguments + kwargs = args.model_dump() + + # Determine if the callable is asynchronous + if inspect.iscoroutinefunction(self._callable): + return await self._callable(**kwargs) + else: + # Run in a thread to avoid blocking the event loop + return await asyncio.to_thread(self._call_sync, kwargs) + + def _call_sync(self, kwargs: Dict[str, Any]) -> Any: + return self._callable(**kwargs) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/__init__.py b/agent_dhal/agentdhal_extensions/tools/mcp/__init__.py new file mode 100644 index 0000000..fcf4148 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/__init__.py @@ -0,0 +1,22 @@ +from ._actor import McpSessionActor +from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams +from ._factory import mcp_server_tools +from ._session import create_mcp_server_session +from ._sse import SseMcpToolAdapter +from ._stdio import StdioMcpToolAdapter +from ._streamable_http import StreamableHttpMcpToolAdapter +from ._workbench import McpWorkbench + +__all__ = [ + "create_mcp_server_session", + "McpSessionActor", + "StdioMcpToolAdapter", + "StdioServerParams", + "SseMcpToolAdapter", + "SseServerParams", + "StreamableHttpMcpToolAdapter", + "StreamableHttpServerParams", + "McpServerParams", + "mcp_server_tools", + "McpWorkbench", +] diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_actor.py b/agent_dhal/agentdhal_extensions/tools/mcp/_actor.py new file mode 100644 index 0000000..f71bbd2 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_actor.py @@ -0,0 +1,310 @@ +import asyncio +import atexit +import base64 +import io +import logging +from typing import Any, Coroutine, Dict, Mapping, TypedDict + +from agentdhal_core import Component, ComponentBase, ComponentModel, Image +from agentdhal_core.models import ( + AssistantMessage, + ChatCompletionClient, + LLMMessage, + ModelInfo, + SystemMessage, + UserMessage, +) +from PIL import Image as PILImage +from pydantic import BaseModel +from typing_extensions import Self + +from mcp import types as mcp_types +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext + +from ._config import McpServerParams +from ._session import create_mcp_server_session + +logger = logging.getLogger(__name__) + +McpResult = ( + Coroutine[Any, Any, mcp_types.ListToolsResult] + | Coroutine[Any, Any, mcp_types.CallToolResult] + | Coroutine[Any, Any, mcp_types.ListPromptsResult] + | Coroutine[Any, Any, mcp_types.ListResourcesResult] + | Coroutine[Any, Any, mcp_types.ListResourceTemplatesResult] + | Coroutine[Any, Any, mcp_types.ReadResourceResult] + | Coroutine[Any, Any, mcp_types.GetPromptResult] +) +McpFuture = asyncio.Future[McpResult] + + +def _parse_sampling_content( + content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent, model_info: ModelInfo +) -> str | Image: + """Convert MCP content types to Autogen content types.""" + if content.type == "text": + return content.text + elif content.type == "image": + if not model_info["vision"]: + raise ValueError("Sampling model does not support image content.") + # Decode base64 image data and create PIL Image + image_data = base64.b64decode(content.data) + pil_image = PILImage.open(io.BytesIO(image_data)) + return Image.from_pil(pil_image) + else: + raise ValueError(f"Unsupported content type: {content.type}") + + +def _parse_sampling_message(message: mcp_types.SamplingMessage, model_info: ModelInfo) -> LLMMessage: + """Convert MCP sampling messages to Autogen messages.""" + content = _parse_sampling_content(message.content, model_info=model_info) + if message.role == "user": + return UserMessage( + source="user", + content=[content], + ) + elif message.role == "assistant": + assert isinstance(content, str), "Assistant messages only support string content." + return AssistantMessage( + source="assistant", + content=content, + ) + else: + raise ValueError(f"Unrecognized message role: {message.role}") + + +class McpActorArgs(TypedDict): + name: str | None + kargs: Mapping[str, Any] + + +class McpSessionActorConfig(BaseModel): + server_params: McpServerParams + model_client: ComponentModel | Dict[str, Any] | None = None + + +class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]): + component_type = "mcp_session_actor" + component_config_schema = McpSessionActorConfig + component_provider_override = "agentdhal_extensions.tools.mcp.McpSessionActor" + + server_params: McpServerParams + + # model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, server_params: McpServerParams, model_client: ChatCompletionClient | None = None) -> None: + self.server_params: McpServerParams = server_params + self._model_client = model_client + self.name = "mcp_session_actor" + self.description = "MCP session actor" + self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + self._actor_task: asyncio.Task[Any] | None = None + self._shutdown_future: asyncio.Future[Any] | None = None + self._active = False + self._initialize_result: mcp_types.InitializeResult | None = None + atexit.register(self._sync_shutdown) + + @property + def initialize_result(self) -> mcp_types.InitializeResult | None: + return self._initialize_result + + async def initialize(self) -> None: + if not self._active: + self._active = True + self._actor_task = asyncio.create_task(self._run_actor()) + + async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture: + if not self._active: + raise RuntimeError("MCP Actor not running, call initialize() first") + if self._actor_task and self._actor_task.done(): + raise RuntimeError("MCP actor task crashed", self._actor_task.exception()) + fut: asyncio.Future[McpFuture] = asyncio.Future() + if type in {"list_tools", "list_prompts", "list_resources", "list_resource_templates", "shutdown"}: + await self._command_queue.put({"type": type, "future": fut}) + res = await fut + elif type in {"call_tool", "read_resource", "get_prompt"}: + if args is None: + raise ValueError(f"args is required for {type}") + name = args.get("name", None) + kwargs = args.get("kargs", {}) + if type == "call_tool" and name is None: + raise ValueError("name is required for call_tool") + elif type == "read_resource": + uri = kwargs.get("uri", None) + if uri is None: + raise ValueError("uri is required for read_resource") + await self._command_queue.put({"type": type, "uri": uri, "future": fut}) + elif type == "get_prompt": + if name is None: + raise ValueError("name is required for get_prompt") + prompt_args = kwargs.get("arguments", None) + await self._command_queue.put({"type": type, "name": name, "args": prompt_args, "future": fut}) + else: # call_tool + await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut}) + res = await fut + else: + raise ValueError(f"Unknown command type: {type}") + return res + + async def close(self) -> None: + if not self._active or self._actor_task is None: + return + self._shutdown_future = asyncio.Future() + await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future}) + await self._shutdown_future + await self._actor_task + self._active = False + + async def _sampling_callback( + self, + context: RequestContext[ClientSession, Any], + params: mcp_types.CreateMessageRequestParams, + ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: + """Handle sampling requests using the provided model client.""" + if self._model_client is None: + # Return an error when no model client is available + return mcp_types.ErrorData( + code=mcp_types.INVALID_REQUEST, + message="No model client available for sampling.", + data=None, + ) + + llm_messages: list[LLMMessage] = [] + + try: + if params.systemPrompt: + llm_messages.append(SystemMessage(content=params.systemPrompt)) + + for mcp_message in params.messages: + llm_messages.append(_parse_sampling_message(mcp_message, model_info=self._model_client.model_info)) + + except Exception as e: + return mcp_types.ErrorData( + code=mcp_types.INVALID_PARAMS, + message="Error processing sampling messages.", + data=f"{type(e).__name__}: {e}", + ) + + try: + result = await self._model_client.create(messages=llm_messages) + + content = result.content + if not isinstance(content, str): + content = str(content) + + return mcp_types.CreateMessageResult( + role="assistant", + content=mcp_types.TextContent(type="text", text=content), + model=self._model_client.model_info["family"], + stopReason=result.finish_reason, + ) + except Exception as e: + return mcp_types.ErrorData( + code=mcp_types.INTERNAL_ERROR, + message="Error sampling from model client.", + data=f"{type(e).__name__}: {e}", + ) + + async def _run_actor(self) -> None: + result: McpResult + try: + async with create_mcp_server_session( + self.server_params, sampling_callback=self._sampling_callback + ) as session: + # Save the initialize result + self._initialize_result = await session.initialize() + while True: + cmd = await self._command_queue.get() + if cmd["type"] == "shutdown": + cmd["future"].set_result("ok") + break + elif cmd["type"] == "call_tool": + try: + result = session.call_tool(name=cmd["name"], arguments=cmd["args"]) + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "read_resource": + try: + result = session.read_resource(uri=cmd["uri"]) + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "get_prompt": + try: + result = session.get_prompt(name=cmd["name"], arguments=cmd["args"]) + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_tools": + try: + result = session.list_tools() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_prompts": + try: + result = session.list_prompts() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_resources": + try: + result = session.list_resources() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_resource_templates": + try: + result = session.list_resource_templates() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + except Exception as e: + if self._shutdown_future and not self._shutdown_future.done(): + self._shutdown_future.set_exception(e) + else: + logger.exception("Exception in MCP actor task") + finally: + self._active = False + self._actor_task = None + + def _sync_shutdown(self) -> None: + if not self._active or self._actor_task is None: + return + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No loop available — interpreter is likely shutting down + return + + if loop.is_closed(): + return + + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + + def _to_config(self) -> McpSessionActorConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + McpSessionConfig: The configuration of the adapter. + """ + return McpSessionActorConfig(server_params=self.server_params) + + @classmethod + def _from_config(cls, config: McpSessionActorConfig) -> Self: + """ + Create an instance of McpSessionActor from its configuration. + + Args: + config (McpSessionConfig): The configuration of the adapter. + + Returns: + McpSessionActor: An instance of SseMcpToolAdapter. + """ + return cls(server_params=config.server_params) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_base.py b/agent_dhal/agentdhal_extensions/tools/mcp/_base.py new file mode 100644 index 0000000..bf655f9 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_base.py @@ -0,0 +1,190 @@ +import asyncio +import builtins +import json +from abc import ABC +from typing import Any, Dict, Generic, Sequence, Type, TypeVar + +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool +from agentdhal_core.utils import schema_to_pydantic_model +from pydantic import BaseModel +from pydantic.networks import AnyUrl + +from mcp import ClientSession, Tool +from mcp.types import AudioContent, ContentBlock, EmbeddedResource, ImageContent, ResourceLink, TextContent + +from ._config import McpServerParams +from ._session import create_mcp_server_session + +TServerParams = TypeVar("TServerParams", bound=McpServerParams) + + +class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): + """ + Base adapter class for MCP tools to make them compatible with AutoGen. + + Args: + server_params (TServerParams): Parameters for the MCP server connection. + tool (Tool): The MCP tool to wrap. + """ + + component_type = "tool" + + def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None: + self._tool = tool + self._server_params = server_params + self._session = session + + # Extract name and description + name = tool.name + description = tool.description or "" + + # Create the input model from the tool's schema + input_model = schema_to_pydantic_model(tool.inputSchema) + + # Use Any as return type since MCP tool returns can vary + return_type: Type[Any] = object + + super().__init__(input_model, return_type, name, description) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + """ + Run the MCP tool with the provided arguments. + + Args: + args (BaseModel): The arguments to pass to the tool. + cancellation_token (CancellationToken): Token to signal cancellation. + + Returns: + Any: The result of the tool execution. + + Raises: + Exception: If the operation is cancelled or the tool execution fails. + """ + # Convert the input model to a dictionary + # Exclude unset values to avoid sending them to the MCP servers which may cause errors + # for many servers. + kwargs = args.model_dump(exclude_unset=True) + + if self._session is not None: + # If a session is provided, use it directly. + session = self._session + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + async with create_mcp_server_session(self._server_params) as session: + await session.initialize() + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]: + """ + Normalizes a raw tool output payload into a list of content items. + - If payload is already a sequence of ContentBlock items, it's converted to a list and returned. + - If payload is a single ContentBlock item, it's wrapped in a list. + - If payload is a string, it's wrapped in [TextContent(text=payload)]. + - Otherwise, the payload is stringified and wrapped in [TextContent(text=str(payload))]. + """ + if isinstance(payload, Sequence) and all( + isinstance(item, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)) + for item in payload + ): + return list(payload) + elif isinstance(payload, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)): + return [payload] + elif isinstance(payload, str): + return [TextContent(text=payload, type="text")] + else: + return [TextContent(text=str(payload), type="text")] + + async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any: + exceptions_to_catch: tuple[Type[BaseException], ...] + if hasattr(builtins, "ExceptionGroup"): + exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup) + else: + exceptions_to_catch = (asyncio.CancelledError,) + + try: + if cancellation_token.is_cancelled(): + raise asyncio.CancelledError("Operation cancelled") + + result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args)) + cancellation_token.link_future(result_future) + result = await result_future + + normalized_content_list = self._normalize_payload_to_content_list(result.content) + + if result.isError: + serialized_error_message = self.return_value_as_string(normalized_content_list) + raise Exception(serialized_error_message) + return normalized_content_list + + except exceptions_to_catch: + # Re-raise these specific exception types directly. + raise + + @classmethod + async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]": + """ + Create an instance of McpToolAdapter from server parameters and tool name. + + Args: + server_params (TServerParams): Parameters for the MCP server connection. + tool_name (str): The name of the tool to wrap. + + Returns: + McpToolAdapter[TServerParams]: An instance of McpToolAdapter. + + Raises: + ValueError: If the tool with the specified name is not found. + """ + async with create_mcp_server_session(server_params) as session: + await session.initialize() + + tools_response = await session.list_tools() + matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None) + + if matching_tool is None: + raise ValueError( + f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}" + ) + + return cls(server_params=server_params, tool=matching_tool) + + def return_value_as_string(self, value: list[Any]) -> str: + """Return a string representation of the result.""" + + def serialize_item(item: Any) -> dict[str, Any]: + if isinstance(item, (TextContent, ImageContent, AudioContent)): + dumped = item.model_dump() + # Remove the 'meta' field if it exists and is None (for backward compatibility) + if dumped.get("meta") is None: + dumped.pop("meta", None) + return dumped + elif isinstance(item, EmbeddedResource): + type = item.type + resource = {} + for key, val in item.resource.model_dump().items(): + # Skip 'meta' field if it's None (for backward compatibility) + if key == "meta" and val is None: + continue + if isinstance(val, AnyUrl): + resource[key] = str(val) + else: + resource[key] = val + dumped_annotations = item.annotations.model_dump() if item.annotations else None + # Remove 'meta' from annotations if it exists and is None + if dumped_annotations and dumped_annotations.get("meta") is None: + dumped_annotations.pop("meta", None) + return {"type": type, "resource": resource, "annotations": dumped_annotations} + elif isinstance(item, ResourceLink): + dumped = item.model_dump() + # Remove the 'meta' field if it exists and is None (for backward compatibility) + if dumped.get("meta") is None: + dumped.pop("meta", None) + # Convert AnyUrl to string for JSON serialization + if "uri" in dumped and isinstance(dumped["uri"], AnyUrl): + dumped["uri"] = str(dumped["uri"]) + return dumped + else: + return {} + + return json.dumps([serialize_item(item) for item in value]) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_config.py b/agent_dhal/agentdhal_extensions/tools/mcp/_config.py new file mode 100644 index 0000000..d7884f4 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_config.py @@ -0,0 +1,42 @@ +from typing import Any, Literal + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from mcp import StdioServerParameters + + +class StdioServerParams(StdioServerParameters): + """Parameters for connecting to an MCP server over STDIO.""" + + type: Literal["StdioServerParams"] = "StdioServerParams" + + read_timeout_seconds: float = 5 + + +class SseServerParams(BaseModel): + """Parameters for connecting to an MCP server over SSE.""" + + type: Literal["SseServerParams"] = "SseServerParams" + + url: str # The SSE endpoint URL. + headers: dict[str, Any] | None = None # Optional headers to include in requests. + timeout: float = 5 # HTTP timeout for regular operations. + sse_read_timeout: float = 60 * 5 # Timeout for SSE read operations. + + +class StreamableHttpServerParams(BaseModel): + """Parameters for connecting to an MCP server over Streamable HTTP.""" + + type: Literal["StreamableHttpServerParams"] = "StreamableHttpServerParams" + + url: str # The endpoint URL. + headers: dict[str, Any] | None = None # Optional headers to include in requests. + timeout: float = 30.0 # HTTP timeout for regular operations in seconds. + sse_read_timeout: float = 300.0 # Timeout for SSE read operations in seconds. + terminate_on_close: bool = True + + +McpServerParams = Annotated[ + StdioServerParams | SseServerParams | StreamableHttpServerParams, Field(discriminator="type") +] diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_factory.py b/agent_dhal/agentdhal_extensions/tools/mcp/_factory.py new file mode 100644 index 0000000..d03675e --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_factory.py @@ -0,0 +1,214 @@ +from mcp import ClientSession + +from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams +from ._session import create_mcp_server_session +from ._sse import SseMcpToolAdapter +from ._stdio import StdioMcpToolAdapter +from ._streamable_http import StreamableHttpMcpToolAdapter + + +async def mcp_server_tools( + server_params: McpServerParams, + session: ClientSession | None = None, +) -> list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]: + """Creates a list of MCP tool adapters that can be used with AutoGen agents. + + .. warning:: + + Only connect to trusted MCP servers, especially when using + `StdioServerParams` as it executes commands in the local environment. + + This factory function connects to an MCP server and returns adapters for all available tools. + The adapters can be directly assigned to an AutoGen agent's tools list. + + .. note:: + + To use this function, you need to install `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[mcp]" + + Args: + server_params (McpServerParams): Connection parameters for the MCP server. + Can be either StdioServerParams for command-line tools or + SseServerParams and StreamableHttpServerParams for HTTP/SSE services. + session (ClientSession | None): Optional existing session to use. This is used + when you want to reuse an existing connection to the MCP server. The session + will be reused when creating the MCP tool adapters. + + Returns: + list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]: + A list of tool adapters ready to use with AutoGen agents. + + Examples: + + **Local file system MCP service over standard I/O example:** + + Install the filesystem server package from npm (requires Node.js 16+ and npm). + + .. code-block:: bash + + npm install -g @modelcontextprotocol/server-filesystem + + Create an agent that can use all tools from the local filesystem MCP server. + + .. code-block:: python + + import asyncio + from pathlib import Path + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_core import CancellationToken + + + async def main() -> None: + # Setup server params for local filesystem access + desktop = str(Path.home() / "Desktop") + server_params = StdioServerParams( + command="npx.cmd", args=["-y", "@modelcontextprotocol/server-filesystem", desktop] + ) + + # Get all available tools from the server + tools = await mcp_server_tools(server_params) + + # Create an agent that can use all the tools + agent = AssistantAgent( + name="file_manager", + model_client=OpenAIChatCompletionClient(model="gpt-4"), + tools=tools, # type: ignore + ) + + # The agent can now use any of the filesystem tools + await agent.run(task="Create a file called test.txt with some content", cancellation_token=CancellationToken()) + + + if __name__ == "__main__": + asyncio.run(main()) + + **Local fetch MCP service over standard I/O example:** + + Install the `mcp-server-fetch` package. + + .. code-block:: bash + + pip install mcp-server-fetch + + Create an agent that can use the `fetch` tool from the local MCP server. + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools + + + async def main() -> None: + # Get the fetch tool from mcp-server-fetch. + fetch_mcp_server = StdioServerParams(command="uvx", args=["mcp-server-fetch"]) + tools = await mcp_server_tools(fetch_mcp_server) + + # Create an agent that can use the fetch tool. + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent(name="fetcher", model_client=model_client, tools=tools, reflect_on_tool_use=True) # type: ignore + + # Let the agent fetch the content of a URL and summarize it. + result = await agent.run(task="Summarize the content of https://en.wikipedia.org/wiki/Seattle") + print(result.messages[-1]) + + + asyncio.run(main()) + + **Sharing an MCP client session across multiple tools:** + + You can create a single MCP client session and share it across multiple tools. + This is sometimes required when the server maintains a session state + (e.g., a browser state) that should be reused for multiple requests. + + The following example show how to create a single MCP client session + to a local `Playwright `_ + server and share it across multiple tools. + + + .. code-block:: python + + import asyncio + + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.conditions import TextMentionTermination + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False) # type: ignore + params = StdioServerParams( + command="npx", + args=["@playwright/mcp@latest"], + read_timeout_seconds=60, + ) + async with create_mcp_server_session(params) as session: + await session.initialize() + tools = await mcp_server_tools(server_params=params, session=session) + print(f"Tools: {[tool.name for tool in tools]}") + + agent = AssistantAgent( + name="Assistant", + model_client=model_client, + tools=tools, # type: ignore + ) + + termination = TextMentionTermination("TERMINATE") + team = RoundRobinGroupChat([agent], termination_condition=termination) + await Console( + team.run_stream( + task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page." + ) + ) + + + asyncio.run(main()) + + + **Remote MCP service over SSE example:** + + .. code-block:: python + + from agentdhal_extensions.tools.mcp import SseServerParams, mcp_server_tools + + + async def main() -> None: + # Setup server params for remote service + server_params = SseServerParams(url="https://api.example.com/mcp", headers={"Authorization": "Bearer token"}) + + # Get all available tools + tools = await mcp_server_tools(server_params) + + # Create an agent with all tools + agent = AssistantAgent(name="tool_user", model_client=OpenAIChatCompletionClient(model="gpt-4"), tools=tools) # type: ignore + + For more examples and detailed usage, see the samples directory in the package repository. + """ + if session is None: + async with create_mcp_server_session(server_params) as temp_session: + await temp_session.initialize() + + tools = await temp_session.list_tools() + else: + tools = await session.list_tools() + + if isinstance(server_params, StdioServerParams): + return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] + elif isinstance(server_params, SseServerParams): + return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] + elif isinstance(server_params, StreamableHttpServerParams): + return [ + StreamableHttpMcpToolAdapter(server_params=server_params, tool=tool, session=session) + for tool in tools.tools + ] + raise ValueError(f"Unsupported server params type: {type(server_params)}") diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_session.py b/agent_dhal/agentdhal_extensions/tools/mcp/_session.py new file mode 100644 index 0000000..04ea17f --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_session.py @@ -0,0 +1,55 @@ +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import AsyncGenerator + +from mcp import ClientSession +from mcp.client.session import SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client + +from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams + + +@asynccontextmanager +async def create_mcp_server_session( + server_params: McpServerParams, sampling_callback: SamplingFnT | None = None +) -> AsyncGenerator[ClientSession, None]: + """Create an MCP client session for the given server parameters.""" + if isinstance(server_params, StdioServerParams): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read_stream=read, + write_stream=write, + read_timeout_seconds=timedelta(seconds=server_params.read_timeout_seconds), + sampling_callback=sampling_callback, + ) as session: + yield session + elif isinstance(server_params, SseServerParams): + async with sse_client(**server_params.model_dump(exclude={"type"})) as (read, write): + async with ClientSession( + read_stream=read, + write_stream=write, + read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout), + sampling_callback=sampling_callback, + ) as session: + yield session + elif isinstance(server_params, StreamableHttpServerParams): + # Convert float seconds to timedelta for the streamablehttp_client + params_dict = server_params.model_dump(exclude={"type"}) + params_dict["timeout"] = timedelta(seconds=server_params.timeout) + params_dict["sse_read_timeout"] = timedelta(seconds=server_params.sse_read_timeout) + + async with streamablehttp_client(**params_dict) as ( + read, + write, + session_id_callback, # type: ignore[assignment, unused-variable] + ): + # TODO: Handle session_id_callback if needed + async with ClientSession( + read_stream=read, + write_stream=write, + read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout), + sampling_callback=sampling_callback, + ) as session: + yield session diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_sse.py b/agent_dhal/agentdhal_extensions/tools/mcp/_sse.py new file mode 100644 index 0000000..6e45a5b --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_sse.py @@ -0,0 +1,116 @@ +from agentdhal_core import Component +from pydantic import BaseModel +from typing_extensions import Self + +from mcp import ClientSession, Tool + +from ._base import McpToolAdapter +from ._config import SseServerParams + + +class SseMcpToolAdapterConfig(BaseModel): + """Configuration for the MCP tool adapter.""" + + server_params: SseServerParams + tool: Tool + + +class SseMcpToolAdapter( + McpToolAdapter[SseServerParams], + Component[SseMcpToolAdapterConfig], +): + """ + Allows you to wrap an MCP tool running over Server-Sent Events (SSE) and make it available to AutoGen. + + This adapter enables using MCP-compatible tools that communicate over HTTP with SSE + with AutoGen agents. Common use cases include integrating with remote MCP services, + cloud-based tools, and web APIs that implement the Model Context Protocol (MCP). + + .. note:: + + To use this class, you need to install `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[mcp]" + + Args: + server_params (SseServerParameters): Parameters for the MCP server connection, + including URL, headers, and timeouts. + tool (Tool): The MCP tool to wrap. + session (ClientSession, optional): The MCP client session to use. If not provided, + it will create a new session. This is useful for testing or when you want to + manage the session lifecycle yourself. + + Examples: + Use a remote translation service that implements MCP over SSE to create tools + that allow AutoGen agents to perform translations: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import SseMcpToolAdapter, SseServerParams + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core import CancellationToken + + + async def main() -> None: + # Create server params for the remote MCP service + server_params = SseServerParams( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"}, + timeout=30, # Connection timeout in seconds + ) + + # Get the translation tool from the server + adapter = await SseMcpToolAdapter.from_server_params(server_params, "translate") + + # Create an agent that can use the translation tool + model_client = OpenAIChatCompletionClient(model="gpt-4") + agent = AssistantAgent( + name="translator", + model_client=model_client, + tools=[adapter], + system_message="You are a helpful translation assistant.", + ) + + # Let the agent translate some text + await Console( + agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken()) + ) + + + if __name__ == "__main__": + asyncio.run(main()) + + """ + + component_config_schema = SseMcpToolAdapterConfig + component_provider_override = "agentdhal_extensions.tools.mcp.SseMcpToolAdapter" + + def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None: + super().__init__(server_params=server_params, tool=tool, session=session) + + def _to_config(self) -> SseMcpToolAdapterConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + SseMcpToolAdapterConfig: The configuration of the adapter. + """ + return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + + @classmethod + def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: + """ + Create an instance of SseMcpToolAdapter from its configuration. + + Args: + config (SseMcpToolAdapterConfig): The configuration of the adapter. + + Returns: + SseMcpToolAdapter: An instance of SseMcpToolAdapter. + """ + return cls(server_params=config.server_params, tool=config.tool) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_stdio.py b/agent_dhal/agentdhal_extensions/tools/mcp/_stdio.py new file mode 100644 index 0000000..23379e7 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_stdio.py @@ -0,0 +1,74 @@ +from agentdhal_core import Component +from pydantic import BaseModel +from typing_extensions import Self + +from mcp import ClientSession, Tool + +from ._base import McpToolAdapter +from ._config import StdioServerParams + + +class StdioMcpToolAdapterConfig(BaseModel): + """Configuration for the MCP tool adapter.""" + + server_params: StdioServerParams + tool: Tool + + +class StdioMcpToolAdapter( + McpToolAdapter[StdioServerParams], + Component[StdioMcpToolAdapterConfig], +): + """Allows you to wrap an MCP tool running over STDIO and make it available to AutoGen. + + This adapter enables using MCP-compatible tools that communicate over standard input/output + with AutoGen agents. Common use cases include wrapping command-line tools and local services + that implement the Model Context Protocol (MCP). + + .. note:: + + To use this class, you need to install `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[mcp]" + + + Args: + server_params (StdioServerParams): Parameters for the MCP server connection, + including command to run and its arguments + tool (Tool): The MCP tool to wrap + session (ClientSession, optional): The MCP client session to use. If not provided, + a new session will be created. This is useful for testing or when you want to + manage the session lifecycle yourself. + + See :func:`~agentdhal_extensions.tools.mcp.mcp_server_tools` for examples. + """ + + component_config_schema = StdioMcpToolAdapterConfig + component_provider_override = "agentdhal_extensions.tools.mcp.StdioMcpToolAdapter" + + def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None: + super().__init__(server_params=server_params, tool=tool, session=session) + + def _to_config(self) -> StdioMcpToolAdapterConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + StdioMcpToolAdapterConfig: The configuration of the adapter. + """ + return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + + @classmethod + def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: + """ + Create an instance of StdioMcpToolAdapter from its configuration. + + Args: + config (StdioMcpToolAdapterConfig): The configuration of the adapter. + + Returns: + StdioMcpToolAdapter: An instance of StdioMcpToolAdapter. + """ + return cls(server_params=config.server_params, tool=config.tool) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_streamable_http.py b/agent_dhal/agentdhal_extensions/tools/mcp/_streamable_http.py new file mode 100644 index 0000000..b34cfd7 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_streamable_http.py @@ -0,0 +1,121 @@ +from agentdhal_core import Component +from pydantic import BaseModel +from typing_extensions import Self + +from mcp import ClientSession, Tool + +from ._base import McpToolAdapter +from ._config import StreamableHttpServerParams + + +class StreamableHttpMcpToolAdapterConfig(BaseModel): + """Configuration for the MCP tool adapter.""" + + server_params: StreamableHttpServerParams + tool: Tool + + +class StreamableHttpMcpToolAdapter( + McpToolAdapter[StreamableHttpServerParams], + Component[StreamableHttpMcpToolAdapterConfig], +): + """ + Allows you to wrap an MCP tool running over Streamable HTTP and make it available to AutoGen. + + This adapter enables using MCP-compatible tools that communicate over Streamable HTTP + with AutoGen agents. Common use cases include integrating with remote MCP services, + cloud-based tools, and web APIs that implement the Model Context Protocol (MCP). + + .. note:: + + To use this class, you need to install `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "agentdhal-ext[mcp]" + + + Args: + server_params (StreamableHttpServerParams): Parameters for the MCP server connection, + including URL, headers, and timeouts. + tool (Tool): The MCP tool to wrap. + session (ClientSession, optional): The MCP client session to use. If not provided, + it will create a new session. This is useful for testing or when you want to + manage the session lifecycle yourself. + + Examples: + Use a remote translation service that implements MCP over Streamable HTTP to + create tools that allow AutoGen agents to perform translations: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import StreamableHttpMcpToolAdapter, StreamableHttpServerParams + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_core import CancellationToken + + + async def main() -> None: + # Create server params for the remote MCP service + server_params = StreamableHttpServerParams( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"}, + timeout=30.0, # HTTP timeout in seconds + sse_read_timeout=300.0, # SSE read timeout in seconds (5 minutes) + terminate_on_close=True, + ) + + # Get the translation tool from the server + adapter = await StreamableHttpMcpToolAdapter.from_server_params(server_params, "translate") + + # Create an agent that can use the translation tool + model_client = OpenAIChatCompletionClient(model="gpt-4") + agent = AssistantAgent( + name="translator", + model_client=model_client, + tools=[adapter], + system_message="You are a helpful translation assistant.", + ) + + # Let the agent translate some text + await Console( + agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken()) + ) + + + if __name__ == "__main__": + asyncio.run(main()) + + """ + + component_config_schema = StreamableHttpMcpToolAdapterConfig + component_provider_override = "agentdhal_extensions.tools.mcp.StreamableHttpMcpToolAdapter" + + def __init__( + self, server_params: StreamableHttpServerParams, tool: Tool, session: ClientSession | None = None + ) -> None: + super().__init__(server_params=server_params, tool=tool, session=session) + + def _to_config(self) -> StreamableHttpMcpToolAdapterConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + StreamableHttpMcpToolAdapterConfig: The configuration of the adapter. + """ + return StreamableHttpMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + + @classmethod + def _from_config(cls, config: StreamableHttpMcpToolAdapterConfig) -> Self: + """ + Create an instance of StreamableHttpMcpToolAdapter from its configuration. + + Args: + config (StreamableHttpMcpToolAdapterConfig): The configuration of the adapter. + + Returns: + StreamableHttpMcpToolAdapter: An instance of StreamableHttpMcpToolAdapter. + """ + return cls(server_params=config.server_params, tool=config.tool) diff --git a/agent_dhal/agentdhal_extensions/tools/mcp/_workbench.py b/agent_dhal/agentdhal_extensions/tools/mcp/_workbench.py new file mode 100644 index 0000000..9317463 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/mcp/_workbench.py @@ -0,0 +1,518 @@ +import asyncio +import builtins +import warnings +from typing import Any, Dict, List, Literal, Mapping, Optional + +from agentdhal_core import CancellationToken, Component, ComponentModel, Image, trace_tool_span +from agentdhal_core.models import ChatCompletionClient +from agentdhal_core.tools import ( + ImageResultContent, + ParametersSchema, + TextResultContent, + ToolOverride, + ToolResult, + ToolSchema, + Workbench, +) +from pydantic import BaseModel, Field +from typing_extensions import Self + +from mcp.types import ( + CallToolResult, + EmbeddedResource, + GetPromptResult, + ImageContent, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + ReadResourceResult, + TextContent, +) + +from ._actor import McpSessionActor +from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams + + +class McpWorkbenchConfig(BaseModel): + server_params: McpServerParams + tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict) + model_client: ComponentModel | Dict[str, Any] | None = None + + +class McpWorkbenchState(BaseModel): + type: Literal["McpWorkBenchState"] = "McpWorkBenchState" + + +class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): + """A workbench that wraps an MCP server and provides an interface + to list and call tools provided by the server. + + .. warning:: + + Only connect to trusted MCP servers, especially when using + `StdioServerParams` as it executes commands in the local environment. + + This workbench should be used as a context manager to ensure proper + initialization and cleanup of the underlying MCP session. + + .. list-table:: MCP Support + :header-rows: 1 + :widths: 30 70 + + * - MCP Capability + - Supported Features + * - Tools + - list_tools, call_tool + * - Resources + - list_resources, read_resource + * - ResourceTemplates + - list_resource_templates, read_resource_template + * - Prompts + - list_prompts, get_prompt + * - Sampling + - Optional support via model_client + * - Roots + - not supported + * - Ellicitation + - not supported + + Args: + server_params (McpServerParams): The parameters to connect to the MCP server. + This can be either a :class:`StdioServerParams` or :class:`SseServerParams`. + tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool + names to override configurations for name and/or description. This allows + customizing how server tools appear to consumers while maintaining the underlying + tool functionality. + model_client: Optional chat completion client to handle sampling requests + from MCP servers that support the sampling capability. This allows MCP + servers to request text generation from a language model during tool + execution. If not provided, sampling requests will return an error. + + Raises: + ValueError: If there are conflicts in tool override names. + + Examples: + + Here is a simple example of how to use the workbench with a `mcp-server-fetch` server: + + .. code-block:: python + + import asyncio + + from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams + + + async def main() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + # You can also use `start()` and `stop()` to manage the session. + async with McpWorkbench(server_params=params) as workbench: + tools = await workbench.list_tools() + print(tools) + result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}) + print(result) + + + asyncio.run(main()) + + Example of using tool overrides: + + .. code-block:: python + + import asyncio + from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams + from agentdhal_core.tools import ToolOverride + + + async def main() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + # Override the fetch tool's name and description + overrides = { + "fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling") + } + + async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench: + tools = await workbench.list_tools() + # The tool will now appear as "web_fetch" with the new description + print(tools) + # Call the overridden tool + result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"}) + print(result) + + + asyncio.run(main()) + + Example of using the workbench with the `GitHub MCP Server `_: + + .. code-block:: python + + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + server_params = StdioServerParams( + command="docker", + args=[ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server", + ], + env={ + "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", + }, + ) + async with McpWorkbench(server_params) as mcp: + agent = AssistantAgent( + "github_assistant", + model_client=model_client, + workbench=mcp, + reflect_on_tool_use=True, + model_client_stream=True, + ) + await Console(agent.run_stream(task="Is there a repository named Autogen")) + + + asyncio.run(main()) + + Example of using the workbench with the `Playwright MCP Server `_: + + .. code-block:: python + + # First run `npm install -g @playwright/mcp@latest` to install the MCP server. + import asyncio + from agentdhal_agentchat.agents import AssistantAgent + from agentdhal_agentchat.teams import RoundRobinGroupChat + from agentdhal_agentchat.conditions import TextMessageTermination + from agentdhal_agentchat.ui import Console + from agentdhal_extensions.models.openai import OpenAIChatCompletionClient + from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") + server_params = StdioServerParams( + command="npx", + args=[ + "@playwright/mcp@latest", + "--headless", + ], + ) + async with McpWorkbench(server_params) as mcp: + agent = AssistantAgent( + "web_browsing_assistant", + model_client=model_client, + workbench=mcp, + model_client_stream=True, + ) + team = RoundRobinGroupChat( + [agent], + termination_condition=TextMessageTermination(source="web_browsing_assistant"), + ) + await Console(team.run_stream(task="Find out how many contributors for the microsoft/autogen repository")) + + + asyncio.run(main()) + + """ + + component_provider_override = "agentdhal_extensions.tools.mcp.McpWorkbench" + component_config_schema = McpWorkbenchConfig + + def __init__( + self, + server_params: McpServerParams, + tool_overrides: Optional[Dict[str, ToolOverride]] = None, + model_client: ChatCompletionClient | None = None, + ) -> None: + self._server_params = server_params + self._tool_overrides = tool_overrides or {} + self._model_client = model_client + + # Build reverse mapping from override names to original names for call_tool + self._override_name_to_original: Dict[str, str] = {} + for original_name, override in self._tool_overrides.items(): + override_name = override.name + if override_name and override_name != original_name: + # Check for conflicts with other override names + if override_name in self._override_name_to_original: + existing_original = self._override_name_to_original[override_name] + raise ValueError( + f"Tool override name '{override_name}' is used by multiple tools: " + f"'{existing_original}' and '{original_name}'. Override names must be unique." + ) + self._override_name_to_original[override_name] = original_name + + # self._session: ClientSession | None = None + self._actor: McpSessionActor | None = None + self._actor_loop: asyncio.AbstractEventLoop | None = None + self._read = None + self._write = None + + @property + def server_params(self) -> McpServerParams: + return self._server_params + + async def list_tools(self) -> List[ToolSchema]: + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + result_future = await self._actor.call("list_tools", None) + list_tool_result = await result_future + assert isinstance( + list_tool_result, ListToolsResult + ), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}" + schema: List[ToolSchema] = [] + for tool in list_tool_result.tools: + original_name = tool.name + name = original_name + description = tool.description or "" + + # Apply overrides if they exist for this tool + if original_name in self._tool_overrides: + override = self._tool_overrides[original_name] + if override.name is not None: + name = override.name + if override.description is not None: + description = override.description + + parameters = ParametersSchema( + type="object", + properties=tool.inputSchema.get("properties", {}), + required=tool.inputSchema.get("required", []), + additionalProperties=tool.inputSchema.get("additionalProperties", False), + ) + tool_schema = ToolSchema( + name=name, + description=description, + parameters=parameters, + ) + schema.append(tool_schema) + return schema + + async def call_tool( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> ToolResult: + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + + # Check if the name is an override name and map it back to the original + original_name = self._override_name_to_original.get(name, name) + + with trace_tool_span( + tool_name=name, # Use the requested name for tracing + tool_call_id=call_id, + ): + try: + result_future = await self._actor.call("call_tool", {"name": original_name, "kargs": arguments}) + cancellation_token.link_future(result_future) + result = await result_future + assert isinstance( + result, CallToolResult + ), f"call_tool must return a CallToolResult, instead of : {str(type(result))}" + result_parts: List[TextResultContent | ImageResultContent] = [] + is_error = result.isError + for content in result.content: + if isinstance(content, TextContent): + result_parts.append(TextResultContent(content=content.text)) + elif isinstance(content, ImageContent): + result_parts.append(ImageResultContent(content=Image.from_base64(content.data))) + elif isinstance(content, EmbeddedResource): + # TODO: how to handle embedded resources? + # For now we just use text representation. + result_parts.append(TextResultContent(content=content.model_dump_json())) + else: + raise ValueError(f"Unknown content type from server: {type(content)}") + except Exception as e: + error_message = self._format_errors(e) + is_error = True + result_parts = [TextResultContent(content=error_message)] + return ToolResult(name=name, result=result_parts, is_error=is_error) # Return the requested name + + @property + def initialize_result(self) -> Any: + if self._actor: + return self._actor.initialize_result + + return None + + async def list_prompts(self) -> ListPromptsResult: + """List available prompts from the MCP server.""" + if not self._actor: + await self.start() + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + + result_future = await self._actor.call("list_prompts", None) + list_prompts_result = await result_future + assert isinstance( + list_prompts_result, ListPromptsResult + ), f"list_prompts must return a ListPromptsResult, instead of: {str(type(list_prompts_result))}" + + return list_prompts_result + + async def list_resources(self) -> ListResourcesResult: + """List available resources from the MCP server.""" + if not self._actor: + await self.start() + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + + result_future = await self._actor.call("list_resources", None) + list_resources_result = await result_future + assert isinstance( + list_resources_result, ListResourcesResult + ), f"list_resources must return a ListResourcesResult, instead of: {str(type(list_resources_result))}" + + return list_resources_result + + async def list_resource_templates(self) -> ListResourceTemplatesResult: + """List available resource templates from the MCP server.""" + if not self._actor: + await self.start() + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + + result_future = await self._actor.call("list_resource_templates", None) + list_templates_result = await result_future + assert isinstance( + list_templates_result, ListResourceTemplatesResult + ), f"list_resource_templates must return a ListResourceTemplatesResult, instead of: {str(type(list_templates_result))}" + + return list_templates_result + + async def read_resource(self, uri: str) -> ReadResourceResult: + """Read a resource from the MCP server.""" + if not self._actor: + await self.start() + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + + result_future = await self._actor.call("read_resource", {"name": None, "kargs": {"uri": uri}}) + read_resource_result = await result_future + assert isinstance( + read_resource_result, ReadResourceResult + ), f"read_resource must return a ReadResourceResult, instead of: {str(type(read_resource_result))}" + + return read_resource_result + + async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> GetPromptResult: + """Get a prompt from the MCP server.""" + if not self._actor: + await self.start() + if self._actor is None: + raise RuntimeError("Actor is not initialized. Please check the server connection.") + + result_future = await self._actor.call("get_prompt", {"name": name, "kargs": {"arguments": arguments}}) + get_prompt_result = await result_future + assert isinstance( + get_prompt_result, GetPromptResult + ), f"get_prompt must return a GetPromptResult, instead of: {str(type(get_prompt_result))}" + + return get_prompt_result + + def _format_errors(self, error: Exception) -> str: + """Recursively format errors into a string.""" + + error_message = "" + if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup): + # ExceptionGroup is available in Python 3.11+. + # TODO: how to make this compatible with Python 3.10? + for sub_exception in error.exceptions: # type: ignore + error_message += self._format_errors(sub_exception) # type: ignore + else: + error_message += f"{str(error)}\n" + return error_message + + async def start(self) -> None: + if self._actor: + warnings.warn( + "McpWorkbench is already started. No need to start again.", + UserWarning, + stacklevel=2, + ) + return # Already initialized, no need to start again + + if isinstance(self._server_params, (StdioServerParams, SseServerParams, StreamableHttpServerParams)): + self._actor = McpSessionActor(self._server_params, model_client=self._model_client) + await self._actor.initialize() + self._actor_loop = asyncio.get_event_loop() + else: + raise ValueError(f"Unsupported server params type: {type(self._server_params)}") + + async def stop(self) -> None: + if self._actor: + # Close the actor + await self._actor.close() + self._actor = None + else: + raise RuntimeError("McpWorkbench is not started. Call start() first.") + + async def reset(self) -> None: + pass + + async def save_state(self) -> Mapping[str, Any]: + return McpWorkbenchState().model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + pass + + def _to_config(self) -> McpWorkbenchConfig: + model_client_config = None + if self._model_client is not None: + model_client_config = self._model_client.dump_component() + return McpWorkbenchConfig( + server_params=self._server_params, tool_overrides=self._tool_overrides, model_client=model_client_config + ) + + @classmethod + def _from_config(cls, config: McpWorkbenchConfig) -> Self: + model_client = None + if config.model_client is not None: + model_client = ChatCompletionClient.load_component(config.model_client) + return cls(server_params=config.server_params, tool_overrides=config.tool_overrides, model_client=model_client) + + def __del__(self) -> None: + # Ensure the actor is stopped when the workbench is deleted + # Use getattr to safely handle cases where attributes may not be set (e.g., if __init__ failed) + actor = getattr(self, "_actor", None) + actor_loop = getattr(self, "_actor_loop", None) + + if actor and actor_loop: + if actor_loop.is_running() and not actor_loop.is_closed(): + actor_loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop())) + else: + msg = "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running" + warnings.warn(msg, RuntimeWarning, stacklevel=2) diff --git a/agent_dhal/agentdhal_extensions/tools/semantic_kernel/__init__.py b/agent_dhal/agentdhal_extensions/tools/semantic_kernel/__init__.py new file mode 100644 index 0000000..358a717 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/semantic_kernel/__init__.py @@ -0,0 +1,6 @@ +from ._kernel_function_from_tool import KernelFunctionFromTool, KernelFunctionFromToolSchema + +__all__ = [ + "KernelFunctionFromTool", + "KernelFunctionFromToolSchema", +] diff --git a/agent_dhal/agentdhal_extensions/tools/semantic_kernel/_kernel_function_from_tool.py b/agent_dhal/agentdhal_extensions/tools/semantic_kernel/_kernel_function_from_tool.py new file mode 100644 index 0000000..b520d40 --- /dev/null +++ b/agent_dhal/agentdhal_extensions/tools/semantic_kernel/_kernel_function_from_tool.py @@ -0,0 +1,94 @@ +from typing import Any, TypeVar + +from agentdhal_core import CancellationToken +from agentdhal_core.tools import BaseTool, ToolSchema +from pydantic import BaseModel + +from semantic_kernel.functions import KernelFunctionFromMethod, KernelFunctionFromPrompt, kernel_function +from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +from semantic_kernel.prompt_template.input_variable import InputVariable +from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig + +InputT = TypeVar("InputT", bound=BaseModel) +OutputT = TypeVar("OutputT", bound=BaseModel) + + +class KernelFunctionFromTool(KernelFunctionFromMethod): + def __init__(self, tool: BaseTool[InputT, OutputT], plugin_name: str | None = None): + # Get the pydantic model types from the tool + args_type = tool.args_type() + return_type = tool.return_type() + + # 1) Define an async function that calls the tool + @kernel_function(name=tool.name, description=tool.description) + async def tool_method(**kwargs: dict[str, Any]) -> Any: + return await tool.run_json(kwargs, cancellation_token=CancellationToken()) + + # Parse schema for parameters + parameters_meta: list[KernelParameterMetadata] = [] + properties = tool.schema.get("parameters", {}).get("properties", {}) + + # Get the field types from the pydantic model + field_types = args_type.model_fields + + for prop_name, prop_info in properties.items(): + assert prop_name in field_types, f"Property {prop_name} not found in Tool {tool.name}" + assert isinstance(prop_info, dict), f"Property {prop_name} is not a dict in Tool {tool.name}" + + # Get the actual type from the pydantic model field + field_type = field_types[prop_name] + parameters_meta.append( + KernelParameterMetadata( + name=prop_name, + description=field_type.description or "", + default_value=field_type.get_default(), + type=prop_info.get("type", "string"), # type: ignore + type_object=field_type.annotation, + is_required=field_type.is_required(), + ) + ) + + # Create return parameter metadata + return_parameter = KernelParameterMetadata( + name="return", + description=f"Result from '{tool.name}' tool", + default_value=None, + type="object" if issubclass(return_type, BaseModel) else "string", + type_object=return_type, + is_required=True, + ) + + # Initialize the parent class + super().__init__( + method=tool_method, + plugin_name=plugin_name, + parameters=parameters_meta, + return_parameter=return_parameter, + additional_metadata=None, + ) + + self._tool = tool + + +class KernelFunctionFromToolSchema(KernelFunctionFromPrompt): + def __init__(self, tool_schema: ToolSchema, plugin_name: str | None = None): + properties = tool_schema.get("parameters", {}).get("properties", {}) + required = properties.get("required", []) + + prompt_template_config = PromptTemplateConfig( + name=tool_schema.get("name", ""), + description=tool_schema.get("description", ""), + input_variables=[ + InputVariable( + name=prop_name, description=prop_info.get("description", ""), is_required=prop_name in required + ) + for prop_name, prop_info in properties.items() + ], + ) + + super().__init__( + function_name=tool_schema.get("name", ""), + plugin_name=plugin_name, + description=tool_schema.get("description", ""), + prompt_template_config=prompt_template_config, + ) diff --git a/agent_dhal/agentdhal_extensions/ui/__init__.py b/agent_dhal/agentdhal_extensions/ui/__init__.py new file mode 100644 index 0000000..d80224a --- /dev/null +++ b/agent_dhal/agentdhal_extensions/ui/__init__.py @@ -0,0 +1,7 @@ +""" +This module implements utility classes for formatting/printing agent messages. +""" + +from ._rich_console import RichConsole + +__all__ = ["RichConsole"] diff --git a/agent_dhal/agentdhal_extensions/ui/_rich_console.py b/agent_dhal/agentdhal_extensions/ui/_rich_console.py new file mode 100644 index 0000000..b2f92df --- /dev/null +++ b/agent_dhal/agentdhal_extensions/ui/_rich_console.py @@ -0,0 +1,223 @@ +import asyncio +import os +import sys +import time +from typing import ( + AsyncGenerator, + Awaitable, + List, + Optional, + Tuple, + TypeVar, + cast, +) + +from agentdhal_agentchat.base import Response, TaskResult +from agentdhal_agentchat.messages import ( + BaseAgentEvent, + BaseChatMessage, + ModelClientStreamingChunkEvent, + MultiModalMessage, + UserInputRequestedEvent, +) +from agentdhal_agentchat.ui._console import UserInputManager +from agentdhal_core import Image +from agentdhal_core.models import RequestUsage +from rich.align import AlignMethod +from rich.console import Console +from rich.panel import Panel + +AGENT_COLORS = { + "user": "bright_green", + "MagenticOneOrchestrator": "bright_blue", + "WebSurfer": "bright_yellow", + "FileSurfer": "bright_cyan", + "Coder": "bright_magenta", + "Executor": "bright_red", +} +DEFAULT_AGENT_COLOR = "white" + +AGENT_ALIGNMENTS: dict[str, AlignMethod] = {"user": "right", "MagenticOneOrchestrator": "center"} +DEFAULT_AGENT_ALIGNMENT: AlignMethod = "left" + + +def _is_running_in_iterm() -> bool: + return os.getenv("TERM_PROGRAM") == "iTerm.app" + + +def _is_output_a_tty() -> bool: + return sys.stdout.isatty() + + +T = TypeVar("T", bound=TaskResult | Response) + + +def aprint(output: str, end: str = "\n") -> Awaitable[None]: + return asyncio.to_thread(print, output, end=end) + + +def _extract_message_content(message: BaseAgentEvent | BaseChatMessage) -> Tuple[List[str], List[Image]]: + if isinstance(message, MultiModalMessage): + text_parts = [item for item in message.content if isinstance(item, str)] + image_parts = [item for item in message.content if isinstance(item, Image)] + else: + text_parts = [message.to_text()] + image_parts = [] + return text_parts, image_parts + + +async def _aprint_panel(console: Console, text: str, title: str) -> None: + color = AGENT_COLORS.get(title, DEFAULT_AGENT_COLOR) + title_align = AGENT_ALIGNMENTS.get(title, DEFAULT_AGENT_ALIGNMENT) + + await asyncio.to_thread( + console.print, + Panel( + text, + title=title, + title_align=title_align, + border_style=color, + ), + ) + + +async def _aprint_message_content( + console: Console, + text_parts: List[str], + image_parts: List[Image], + source: str, + *, + render_image_iterm: bool = False, +) -> None: + if text_parts: + await _aprint_panel(console, "\n".join(text_parts), source) + + for img in image_parts: + if render_image_iterm: + await aprint(_image_to_iterm(img)) + else: + await aprint("\n") + + +async def RichConsole( + stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None], + *, + no_inline_images: bool = False, + output_stats: bool = False, + user_input_manager: UserInputManager | None = None, +) -> T: + """ + Consumes the message stream from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` + or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream` and renders the messages to the console. + Returns the last processed TaskResult or Response. + + .. note:: + + `output_stats` is experimental and the stats may not be accurate. + It will be improved in future releases. + + Args: + stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render. + This can be from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`. + no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False. + output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False. + + Returns: + last_processed: A :class:`~agentdhal_agentchat.base.TaskResult` if the stream is from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` + or a :class:`~agentdhal_agentchat.base.Response` if the stream is from :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`. + """ + render_image_iterm = _is_running_in_iterm() and _is_output_a_tty() and not no_inline_images + start_time = time.time() + total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + rich_console = Console() + + last_processed: Optional[T] = None + + async for message in stream: + if isinstance(message, TaskResult): + duration = time.time() - start_time + if output_stats: + output = ( + f"Number of messages: {len(message.messages)}\n" + f"Finish reason: {message.stop_reason}\n" + f"Total prompt tokens: {total_usage.prompt_tokens}\n" + f"Total completion tokens: {total_usage.completion_tokens}\n" + f"Duration: {duration:.2f} seconds\n" + ) + await _aprint_panel(rich_console, output, "Summary") + + last_processed = message # type: ignore + + elif isinstance(message, Response): + duration = time.time() - start_time + + # Print final response. + text_parts, image_parts = _extract_message_content(message.chat_message) + if message.chat_message.models_usage: + if output_stats: + text_parts.append( + f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]" + ) + total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens + total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens + + await _aprint_message_content( + rich_console, + text_parts, + image_parts, + message.chat_message.source, + render_image_iterm=render_image_iterm, + ) + + # Print summary. + if output_stats: + num_inner_messages = len(message.inner_messages) if message.inner_messages is not None else 0 + output = ( + f"Number of inner messages: {num_inner_messages}\n" + f"Total prompt tokens: {total_usage.prompt_tokens}\n" + f"Total completion tokens: {total_usage.completion_tokens}\n" + f"Duration: {duration:.2f} seconds\n" + ) + await _aprint_panel(rich_console, output, "Summary") + + # mypy ignore + last_processed = message # type: ignore + # We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event. + elif isinstance(message, UserInputRequestedEvent): + if user_input_manager is not None: + user_input_manager.notify_event_received(message.request_id) + elif isinstance(message, ModelClientStreamingChunkEvent): + # TODO: Handle model client streaming chunk events. + pass + else: + # Cast required for mypy to be happy + message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore + + text_parts, image_parts = _extract_message_content(message) + # Add usage stats if needed + if message.models_usage: + if output_stats: + text_parts.append( + f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]" + ) + total_usage.completion_tokens += message.models_usage.completion_tokens + total_usage.prompt_tokens += message.models_usage.prompt_tokens + + await _aprint_message_content( + rich_console, + text_parts, + image_parts, + message.source, + render_image_iterm=render_image_iterm, + ) + + if last_processed is None: + raise ValueError("No TaskResult or Response was processed.") + + return last_processed + + +# iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html +def _image_to_iterm(image: Image) -> str: + image_data = image.to_base64() + return f"\033]1337;File=inline=1:{image_data}\a\n" diff --git a/agent_dhal/hal.py b/agent_dhal/hal.py new file mode 100644 index 0000000..4a07871 --- /dev/null +++ b/agent_dhal/hal.py @@ -0,0 +1,944 @@ +#!/usr/bin/env python3 +""" +Hal - Primary AI Agent for AgentDhal Framework + +Hal is the main AI assistant agent providing: +- Multi-turn conversations +- Function/tool calling capabilities +- Code execution and analysis +- Team collaboration +- Memory and context management +- Customizable behavior and prompts + +Based on AutoGen AssistantAgent with DarkHal-specific enhancements. +""" + +import asyncio +import subprocess +import sys +import os +import json +import requests +from typing import Any, Dict, List, Optional, Callable, Union +from dataclasses import dataclass +import platform +import shutil +import shlex +import ctypes +import tempfile + +# Import actual working AgentDhal components +from .agentdhal_core import ( + Agent, + AgentId, + MessageContext, + RoutedAgent, + message_handler, + default_subscription, + SingleThreadedAgentRuntime +) + +try: + from .agentdhal_core.models import ( + ChatCompletionClient, + LLMMessage, + SystemMessage, + UserMessage, + ) +except ImportError: + # Define basic message types if not available + class LLMMessage: + def __init__(self, content: str): + self.content = content + + class SystemMessage(LLMMessage): + def __init__(self, content: str): + super().__init__(content) + self.role = "system" + + class UserMessage(LLMMessage): + def __init__(self, content: str): + super().__init__(content) + self.role = "user" + + # Define ChatCompletionClient as a basic class + class ChatCompletionClient: + pass + +try: + from .agentdhal_core.tools import FunctionTool, Tool +except ImportError: + # Create working tool implementation compatible with add_function(func, description) + class FunctionTool: + def __init__(self, func: Callable, description: str = ""): + self.func = func + self.description = description + Tool = FunctionTool # alias for typing + + +class HalModelClient: + """Model client that integrates with DarkHal's LLM runtime.""" + + def __init__(self, model_name: str = "gpt-4"): + self.model_name = model_name + self.llm_model = None + self._initialize_model() + + def _initialize_model(self): + """Initialize the LLM model using DarkHal's runtime.""" + try: + # Import from the main application's LLM runtime + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from llm_runtime import load_model + + # Load model using existing runtime + self.llm_model = load_model( + source=self.model_name, + device="auto", + quantization="4bit" + ) + print(f"[Hal] Initialized model: {self.model_name}") + + except Exception as e: + print(f"[Hal] Could not initialize model {self.model_name}: {e}") + self.llm_model = None + + async def create_chat_completion(self, messages: List[LLMMessage], **kwargs): + """Create chat completion using the loaded model (supports streaming).""" + try: + if not self.llm_model: + raise Exception("No model loaded") + + # Convert messages to text format + conversation = "" + for msg in messages: + if hasattr(msg, 'role') and hasattr(msg, 'content'): + if msg.role == "system": + conversation += f"System: {msg.content}\n\n" + elif msg.role == "user": + conversation += f"User: {msg.content}\n\n" + elif msg.role == "assistant": + conversation += f"Assistant: {msg.content}\n\n" + else: + conversation += f"{msg.content}\n\n" + + conversation += "Assistant: " + + # Callbacks and options + stream = bool(kwargs.get('stream', False)) + on_token = kwargs.get('on_token') + on_complete = kwargs.get('on_complete') + on_error = kwargs.get('on_error') + temperature = kwargs.get('temperature', 0.7) + max_tokens = kwargs.get('max_tokens', 2000) + + # Try to build a config if available (optional) + cfg = None + try: + from llm_runtime import GenerateConfig # type: ignore + cfg = GenerateConfig( + max_tokens=max_tokens, + temperature=temperature + ) + except Exception: + cfg = None + + full_text = "" + + # Streaming if supported + if stream and hasattr(self.llm_model, 'stream'): + try: + iterator = self.llm_model.stream(conversation, cfg) if cfg is not None else self.llm_model.stream(conversation) + for chunk in iterator: + if chunk: + full_text += str(chunk) + if on_token: + try: + on_token("request-1", str(chunk)) + except Exception: + pass + if on_complete: + try: + on_complete("request-1", full_text, {"finish_reason": "stop"}) + except Exception: + pass + except Exception as se: + if on_error: + try: + on_error("request-1", se) + except Exception: + pass + raise + + else: + # Non-streaming path + if hasattr(self.llm_model, 'generate'): + full_text = self.llm_model.generate(conversation, cfg) if cfg is not None else self.llm_model.generate(conversation) + elif hasattr(self.llm_model, '__call__'): + full_text = self.llm_model(conversation) + else: + # Fallback for different model interfaces + full_text = str(self.llm_model) + if on_complete: + try: + on_complete("request-1", full_text, {"finish_reason": "stop"}) + except Exception: + pass + + # Create response object + class CompletionResponse: + def __init__(self, content): + self.content = content + self.function_calls = None + + return CompletionResponse(full_text) + + except Exception as e: + print(f"[Hal] Error generating response: {e}") + + # Return error response + class CompletionResponse: + def __init__(self, content): + self.content = content + self.function_calls = None + + if kwargs.get('on_error'): + try: + kwargs['on_error']("request-1", e) + except Exception: + pass + return CompletionResponse(f"I apologize, I encountered an error: {str(e)}") + + def is_available(self) -> bool: + """Check if the model client is available.""" + return self.llm_model is not None + + +@dataclass +class DhalConfig: + """Configuration for Hal agent.""" + name: str = "Hal" + system_message: str = "You are Hal, an advanced AI assistant integrated into DarkHal 2.0. You help users with coding, analysis, security testing, and general AI tasks." + model: str = "gpt-4" + temperature: float = 0.7 + max_tokens: int = 2000 + tools: List[FunctionTool] = None # type: ignore + memory_limit: int = 10000 + + def __post_init__(self): + if self.tools is None: + self.tools = [] + + +class Dhal(RoutedAgent): + """ + Dhal - The primary AI agent for DarkHal 2.0 + + Dhal provides advanced conversational AI capabilities with: + - Natural language understanding and generation + - Function calling and tool integration + - Code execution and analysis + - Memory and context management + - Team collaboration capabilities + """ + + def __init__( + self, + config: DhalConfig, + model_client: ChatCompletionClient, + agent_id: Optional[AgentId] = None + ): + """Initialize Dhal agent.""" + if agent_id is None: + agent_id = AgentId(config.name, "dhal") + + super().__init__(config.name) + + self.config = config + self.model_client = model_client + self.agent_id = agent_id + + # Initialize conversation memory + self.conversation_history: List[LLMMessage] = [] + if config.system_message: + self.conversation_history.append(SystemMessage(content=config.system_message)) + + # Tools and functions + self.tools = config.tools or [] + self.function_map: Dict[str, Callable] = {} + + # Agent state + self.is_active = False + self.current_task = None + + self._register_default_tools() + + def _register_default_tools(self): + """Register default tools available to Hal.""" + + # Code execution tool + def execute_python(code: str) -> str: + """Execute Python code and return the result.""" + try: + import tempfile + import subprocess + + # Create a temporary file to execute the code + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(code) + temp_file = f.name + + try: + # Execute the code and capture output + result = subprocess.run([sys.executable, temp_file], + capture_output=True, text=True, timeout=30, cwd=os.getcwd()) + + output = "" + if result.stdout: + output += f"Output:\n{result.stdout}" + if result.stderr: + output += f"\nErrors:\n{result.stderr}" + if result.returncode != 0: + output += f"\nReturn code: {result.returncode}" + + return output or "Code executed successfully (no output)" + + finally: + # Always clean up temp file + try: + os.unlink(temp_file) + except: + pass + + except subprocess.TimeoutExpired: + return "Error: Code execution timed out (30 second limit)" + except Exception as e: + return f"Error executing code: {str(e)}" + + # Web search tool + def web_search(query: str) -> str: + """Search the web for information using DuckDuckGo API.""" + try: + import requests + + # Use DuckDuckGo instant answer API + url = "https://api.duckduckgo.com/" + params = { + 'q': query, + 'format': 'json', + 'no_html': '1', + 'skip_disambig': '1' + } + + response = requests.get(url, params=params, timeout=10) + response.raise_for_status() + + data = response.json() + + result = f"Search query: {query}\n\n" + + # Extract useful information + if data.get('AbstractText'): + result += f"Summary: {data['AbstractText']}\n\n" + + if data.get('RelatedTopics'): + result += "Related information:\n" + for i, topic in enumerate(data['RelatedTopics'][:5]): # Limit to first 5 + if isinstance(topic, dict) and 'Text' in topic: + result += f"{i + 1}. {topic['Text']}\n" + elif isinstance(topic, dict) and 'Topics' in topic: + # Handle nested topics + for subtopic in topic['Topics'][:2]: + if 'Text' in subtopic: + result += f"{i + 1}. {subtopic['Text']}\n" + + if data.get('Answer'): + result += f"\nDirect answer: {data['Answer']}\n" + + if data.get('Definition'): + result += f"\nDefinition: {data['Definition']}\n" + + return result if result.strip() != f"Search query: {query}" else f"No specific results found for: {query}" + + except Exception as e: + return f"Error performing web search: {str(e)}" + + # Unrestricted file read + def read_file(filepath: str) -> str: + try: + abs_path = os.path.abspath(filepath) + with open(abs_path, 'r', encoding='utf-8') as f: + content = f.read() + return f"File: {abs_path}\nSize: {len(content)} characters\n\n{content}" + except UnicodeDecodeError: + # Fallback for binary files: return first 4096 bytes + try: + with open(abs_path, 'rb') as f: + data = f.read(4096) + return f"Binary file: {abs_path}\nFirst 4096 bytes:\n{data}" + except Exception as e: + return f"Error reading binary file: {abs_path}\n{e}" + except Exception as e: + return f"Error reading file: {abs_path}\n{e}" + + # Unrestricted file write (creates parent dirs) + def write_file(filepath: str, content: str) -> str: + try: + abs_path = os.path.abspath(filepath) + os.makedirs(os.path.dirname(abs_path) or ".", exist_ok=True) + with open(abs_path, 'w', encoding='utf-8') as f: + f.write(content) + return f"Wrote {len(content)} bytes to {abs_path}" + except Exception as e: + return f"Error writing file: {abs_path}\n{e}" + + # Unrestricted directory list + def list_files(directory: str = ".") -> str: + try: + abs_dir = os.path.abspath(directory) + items = [] + for name in sorted(os.listdir(abs_dir)): + p = os.path.join(abs_dir, name) + if os.path.isdir(p): + items.append(f"[DIR] {name}/") + else: + items.append(f"[FILE] {name} ({os.path.getsize(p)} bytes)") + return f"Directory: {abs_dir}\n\n" + "\n".join(items) + except Exception as e: + return f"Error listing {directory}\n{e}" + + # Unrestricted shell command (optionally with cwd/timeout) + def run_shell_command(command: str, cwd: str = None, timeout: int = 300) -> str: + try: + r = subprocess.run( + command, + shell=True, + cwd=cwd, + capture_output=True, + text=True, + timeout=timeout + ) + out = r.stdout + if r.stderr: + out += f"\n[stderr]\n{r.stderr}" + out += f"\n[exit_code] {r.returncode}" + return out.strip() + except subprocess.TimeoutExpired: + return f"Error: timed out after {timeout}s" + except Exception as e: + return f"Error: {e}" + + # --- NEW: PowerShell (Windows), Bash (Linux/macOS), and Ruby runner --- + def powershell(command: str, cwd: str = None, timeout: int = 120) -> str: + """Run a PowerShell command on Windows (prefers pwsh, falls back to powershell).""" + exe = shutil.which("pwsh") or shutil.which("powershell") + if not exe: + return "Error: PowerShell not found" + cmd = [exe, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-Command", command] + r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout) + out = r.stdout + if r.stderr: + out += f"\n[stderr]\n{r.stderr}" + out += f"\n[exit_code] {r.returncode}" + return out.strip() + + def bash(command: str, cwd: str = None, timeout: int = 120) -> str: + """Run a Bash command on Linux/macOS.""" + exe = shutil.which("bash") + if not exe: + return "Error: bash not found" + cmd = [exe, "-lc", command] + r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout) + out = r.stdout + if r.stderr: + out += f"\n[stderr]\n{r.stderr}" + out += f"\n[exit_code] {r.returncode}" + return out.strip() + + def ruby_run(script_path: str, args: str = "", cwd: str = None, timeout: int = 180, + use_bundler: bool = False) -> str: + """Run a Ruby script. If use_bundler=True and Gemfile+bundle present → `bundle exec ruby`.""" + ruby = shutil.which("ruby") + if not ruby: + return "Error: ruby not found in PATH" + argv = shlex.split(args) if args else [] + if use_bundler and shutil.which("bundle") and os.path.exists(os.path.join(cwd or os.getcwd(), "Gemfile")): + cmd = ["bundle", "exec", ruby, script_path] + argv + else: + cmd = [ruby, script_path] + argv + r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout) + out = r.stdout + if r.stderr: + out += f"\n[stderr]\n{r.stderr}" + out += f"\n[exit_code] {r.returncode}" + return out.strip() + + # --- NEW: Elevated commands --- + def powershell_admin(command: str, cwd: str = None, timeout: int = 300) -> str: + """ + Run an elevated (UAC) PowerShell command on Windows. + Captures output by running a temporary elevated script and reading its output file. + """ + exe = shutil.which("pwsh") or shutil.which("powershell") + if not exe: + return "Error: PowerShell not found" + if platform.system() != "Windows": + return "Error: Admin PowerShell is Windows-only" + + # Prepare temp script and output file + with tempfile.TemporaryDirectory() as td: + out_file = os.path.join(td, "out.txt") + script_file = os.path.join(td, "script.ps1") + # PS script: run the user's command and redirect ALL streams to file + script = f"$ErrorActionPreference='Continue'; & {{ {command} }} *> '{out_file}'" + with open(script_file, "w", encoding="utf-8") as f: + f.write(script) + + # Build a PS command that elevates a new PowerShell instance to run the script + # We pass '-File ""' to the elevated process and wait for it. + args_str = f"-NoProfile -NonInteractive -ExecutionPolicy Bypass -File \"{script_file}\"" + ps_cmd = f"Start-Process -Verb RunAs -FilePath '{exe}' -ArgumentList '{args_str}' -Wait" + + # Launch the elevation prompt + r = subprocess.run( + [exe, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-Command", ps_cmd], + cwd=cwd, capture_output=True, text=True, timeout=timeout + ) + + # Read output written by elevated process + out = "" + try: + if os.path.exists(out_file): + with open(out_file, "r", encoding="utf-8", errors="replace") as f: + out = f.read().strip() + except Exception as e: + out += f"\n[read_error] {e}" + + if r.stderr: + out += f"\n[launcher_stderr]\n{r.stderr.strip()}" + out += f"\n[launcher_exit_code] {r.returncode}" + return out.strip() + + def sudo(command: str, cwd: str = None, timeout: int = 300) -> str: + """ + Run a command with elevated privileges on Linux (and most *nix). + Prefers pkexec (GUI polkit prompt); falls back to sudo (TTY or policykit may prompt). + """ + if platform.system() == "Windows": + return "Error: sudo is not available on Windows" + + if shutil.which("pkexec"): + cmd = ["pkexec", "bash", "-lc", command] + elif shutil.which("sudo"): + # This will require a TTY or desktop prompt depending on environment + cmd = ["sudo", "bash", "-lc", command] + else: + return "Error: Neither pkexec nor sudo found" + + r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout) + out = r.stdout + if r.stderr: + out += f"\n[stderr]\n{r.stderr}" + out += f"\n[exit_code] {r.returncode}" + return out.strip() + + # --- Mouse control (Windows + Linux via xdotool) --- + def mouse_move(x: int, y: int) -> str: + """Move mouse to absolute screen coordinates (x, y).""" + sysname = platform.system() + if sysname == "Windows": + ctypes.windll.user32.SetCursorPos(int(x), int(y)) + return "ok" + elif sysname in ("Linux", "FreeBSD"): + if not shutil.which("xdotool"): + return "Error: xdotool not found (Wayland may not support it)" + r = subprocess.run(["xdotool", "mousemove", str(int(x)), str(int(y))], + capture_output=True, text=True) + return "ok" if r.returncode == 0 else (r.stderr or r.stdout) + else: + return "Unsupported platform" + + def mouse_click(button: str = "left") -> str: + """Click mouse button: left|middle|right.""" + sysname = platform.system() + if sysname == "Windows": + btn = button.lower() + if btn == "left": + down, up = 0x0002, 0x0004 + elif btn == "right": + down, up = 0x0008, 0x0010 + elif btn == "middle": + down, up = 0x0020, 0x0040 + else: + return "Error: unknown button" + ctypes.windll.user32.mouse_event(down, 0, 0, 0, 0) + ctypes.windll.user32.mouse_event(up, 0, 0, 0, 0) + return "ok" + elif sysname in ("Linux", "FreeBSD"): + if not shutil.which("xdotool"): + return "Error: xdotool not found" + map_btn = {"left": "1", "middle": "2", "right": "3"} + code = map_btn.get(button.lower()) + if not code: + return "Error: unknown button" + r = subprocess.run(["xdotool", "click", code], capture_output=True, text=True) + return "ok" if r.returncode == 0 else (r.stderr or r.stdout) + else: + return "Unsupported platform" + + def mouse_scroll(lines: int = 1) -> str: + """Scroll vertically: positive = up, negative = down.""" + sysname = platform.system() + if sysname == "Windows": + # 120 units per wheel click + delta = int(lines) * 120 + ctypes.windll.user32.mouse_event(0x0800, 0, 0, delta, 0) + return "ok" + elif sysname in ("Linux", "FreeBSD"): + if not shutil.which("xdotool"): + return "Error: xdotool not found" + clicks = abs(int(lines)) + button = "4" if int(lines) > 0 else "5" + for _ in range(clicks): + subprocess.run(["xdotool", "click", button], capture_output=True) + return "ok" + else: + return "Unsupported platform" + + # ========================= + # WINDOWS XINPUT (DirectX) + # ========================= + # Tools: xinput_list, xinput_get_state, xinput_set_vibration + def _load_xinput_dll(): + if platform.system() != "Windows": + return None + for dll in ("XInput1_4.dll", "XInput1_3.dll", "XInput9_1_0.dll"): + try: + return ctypes.WinDLL(dll) + except OSError: + continue + return None + + # Structures from XInput + class XINPUT_GAMEPAD(ctypes.Structure): + _fields_ = [ + ("wButtons", ctypes.c_ushort), + ("bLeftTrigger", ctypes.c_ubyte), + ("bRightTrigger", ctypes.c_ubyte), + ("sThumbLX", ctypes.c_short), + ("sThumbLY", ctypes.c_short), + ("sThumbRX", ctypes.c_short), + ("sThumbRY", ctypes.c_short), + ] + + class XINPUT_STATE(ctypes.Structure): + _fields_ = [ + ("dwPacketNumber", ctypes.c_uint), + ("Gamepad", XINPUT_GAMEPAD), + ] + + class XINPUT_VIBRATION(ctypes.Structure): + _fields_ = [ + ("wLeftMotorSpeed", ctypes.c_ushort), + ("wRightMotorSpeed", ctypes.c_ushort), + ] + + # Button bitmasks + XINPUT_BUTTONS = { + "DPAD_UP": 0x0001, + "DPAD_DOWN": 0x0002, + "DPAD_LEFT": 0x0004, + "DPAD_RIGHT": 0x0008, + "START": 0x0010, + "BACK": 0x0020, + "LEFT_THUMB": 0x0040, + "RIGHT_THUMB": 0x0080, + "LEFT_SHOULDER": 0x0100, + "RIGHT_SHOULDER": 0x0200, + "A": 0x1000, + "B": 0x2000, + "X": 0x4000, + "Y": 0x8000, + } + + def xinput_list() -> str: + """List connected XInput (Xbox-compatible) controllers on Windows.""" + if platform.system() != "Windows": + return "Unsupported platform (XInput is Windows/DirectX)" + dll = _load_xinput_dll() + if not dll: + return "Error: XInput DLL not found" + + XInputGetState = dll.XInputGetState + XInputGetState.argtypes = [ctypes.c_uint, ctypes.POINTER(XINPUT_STATE)] + XInputGetState.restype = ctypes.c_uint # 0 = success + + found = [] + for i in range(4): + state = XINPUT_STATE() + rc = XInputGetState(i, ctypes.byref(state)) + found.append({"id": i, "connected": (rc == 0)}) + + return json.dumps(found) + + def xinput_get_state(controller_id: int = 0) -> str: + """Get state for a specific controller id (0-3). Returns JSON with buttons/axes/triggers.""" + if platform.system() != "Windows": + return "Unsupported platform" + dll = _load_xinput_dll() + if not dll: + return "Error: XInput DLL not found" + + XInputGetState = dll.XInputGetState + XInputGetState.argtypes = [ctypes.c_uint, ctypes.POINTER(XINPUT_STATE)] + XInputGetState.restype = ctypes.c_uint + + state = XINPUT_STATE() + rc = XInputGetState(int(controller_id), ctypes.byref(state)) + if rc != 0: + return f"Error: controller {controller_id} not connected" + + gp = state.Gamepad + buttons = gp.wButtons + pressed = [name for name, mask in XINPUT_BUTTONS.items() if buttons & mask] + + data = { + "id": int(controller_id), + "packet": state.dwPacketNumber, + "buttons": pressed, + "left_trigger": int(gp.bLeftTrigger), + "right_trigger": int(gp.bRightTrigger), + "thumb_lx": int(gp.sThumbLX), + "thumb_ly": int(gp.sThumbLY), + "thumb_rx": int(gp.sThumbRX), + "thumb_ry": int(gp.sThumbRY), + } + return json.dumps(data) + + def xinput_set_vibration(controller_id: int = 0, + left_motor: int = 0, + right_motor: int = 0) -> str: + """Set rumble (0-65535 per motor).""" + if platform.system() != "Windows": + return "Unsupported platform" + dll = _load_xinput_dll() + if not dll: + return "Error: XInput DLL not found" + + XInputSetState = dll.XInputSetState + XInputSetState.argtypes = [ctypes.c_uint, ctypes.POINTER(XINPUT_VIBRATION)] + XInputSetState.restype = ctypes.c_uint + + vib = XINPUT_VIBRATION( + wLeftMotorSpeed=max(0, min(65535, int(left_motor))), + wRightMotorSpeed=max(0, min(65535, int(right_motor))) + ) + rc = XInputSetState(int(controller_id), ctypes.byref(vib)) + return "ok" if rc == 0 else f"Error: rc={rc}" + + # Register tools + self.add_function("execute_python", execute_python, "Execute Python code and return results") + self.add_function("web_search", web_search, "Search the web for information using DuckDuckGo") + self.add_function("read_file", read_file, "Read contents of a file") + self.add_function("write_file", write_file, "Write content to a file") + self.add_function("list_files", list_files, "List files and directories") + self.add_function("run_shell_command", run_shell_command, "Run shell commands") + # NEW registrations (standard shells and Ruby) + self.add_function("powershell", powershell, "Run a PowerShell command on Windows") + self.add_function("bash", bash, "Run a Bash command on Linux/macOS") + self.add_function("ruby_run", ruby_run, "Run a Ruby script (optionally via bundler)") + # NEW registrations (elevated) + self.add_function("powershell_admin", powershell_admin, "Run elevated PowerShell with UAC; captures output") + self.add_function("sudo", sudo, "Run a command with sudo or pkexec; user must approve") + # NEW registrations (input control) + self.add_function("mouse_move", mouse_move, "Move mouse to absolute screen coordinates") + self.add_function("mouse_click", mouse_click, "Click mouse button: left|middle|right") + self.add_function("mouse_scroll", mouse_scroll, "Scroll vertically by N lines (±)") + # Windows XInput (DirectX) gamepad tools + self.add_function("xinput_list", xinput_list, "List XInput controllers (Windows/DirectX)") + self.add_function("xinput_get_state", xinput_get_state, "Get XInput state for controller id (0-3)") + self.add_function("xinput_set_vibration", xinput_set_vibration, "Set XInput vibration (rumble)") + + def add_function(self, name: str, func: Callable, description: str): + """Add a function tool to Hal's capabilities.""" + self.function_map[name] = func + tool = FunctionTool(func, description=description) + self.tools.append(tool) + + @message_handler + async def handle_user_message(self, message: str, ctx: MessageContext) -> str: + """Handle user messages and generate responses.""" + try: + # Add user message to conversation history + user_msg = UserMessage(content=message) + self.conversation_history.append(user_msg) + + # Prepare messages for model + messages = self._prepare_messages() + + # Generate response using model client + response = await self.model_client.create_chat_completion( + messages=messages, + model=self.config.model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + tools=self.tools if self.tools else None + ) + # Process response + assistant_message = response.content + # Handle function calls if present + if hasattr(response, 'function_calls') and response.function_calls: + assistant_message = await self._handle_function_calls(response.function_calls) + # Add assistant response to history + self.conversation_history.append(UserMessage(content=assistant_message)) + # Manage memory limits + self._manage_memory() + return assistant_message + except Exception as e: + error_msg = f"Hal encountered an error: {str(e)}" + self.conversation_history.append(UserMessage(content=error_msg)) + return error_msg + + def send_dhal_message( + self, + conversation_id: str, + message: str, + stream: bool = False, + on_token: Optional[Callable[[str, str], None]] = None, + on_complete: Optional[Callable[[str, str, Dict[str, Any]], None]] = None, + on_error: Optional[Callable[[str, Exception], None]] = None + ) -> None: + """ + Public entry to send a message to the agent with optional streaming callbacks. + Runs in a background thread to avoid blocking the UI thread. + """ + # Add user message to history + self.conversation_history.append(UserMessage(content=message)) + messages = self._prepare_messages() + + def _worker(): + import asyncio + + async def _run(): + try: + # Delegate to model client with callbacks + resp = await self.model_client.create_chat_completion( + messages=messages, + model=self.config.model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + tools=self.tools if self.tools else None, + stream=stream, + on_token=(lambda req_id, delta: on_token and on_token(conversation_id, delta)), + on_complete=(lambda req_id, full, meta: ( + self.conversation_history.append(UserMessage(content=full)), + self._manage_memory(), + on_complete and on_complete(conversation_id, full, meta) + )), + on_error=(lambda req_id, err: on_error and on_error(conversation_id, err)) + ) + # If not streaming, we need to append and call on_complete here + if not stream and resp and hasattr(resp, "content"): + assistant_message = resp.content + self.conversation_history.append(UserMessage(content=assistant_message)) + self._manage_memory() + if on_complete: + on_complete(conversation_id, assistant_message, {"finish_reason": "stop"}) + except Exception as e: + if on_error: + on_error(conversation_id, e) + + asyncio.run(_run()) + + import threading as _threading + t = _threading.Thread(target=_worker, daemon=True) + t.start() + + async def _handle_function_calls(self, function_calls: List[Dict]) -> str: + """Handle function calls from the model.""" + results = [] + + for call in function_calls: + func_name = call.get('name') + func_args = call.get('arguments', {}) + + if func_name in self.function_map: + try: + if asyncio.iscoroutinefunction(self.function_map[func_name]): + result = await self.function_map[func_name](**func_args) + else: + result = self.function_map[func_name](**func_args) + + results.append(f"[{func_name}] {result}") + except Exception as e: + results.append(f"[{func_name}] Error: {str(e)}") + else: + results.append(f"[{func_name}] Function not found") + + return "\n".join(results) + + def _prepare_messages(self) -> List[LLMMessage]: + """Prepare messages for the model, respecting context limits.""" + # For now, return all messages. In production, implement smart truncation + return self.conversation_history.copy() + + def _manage_memory(self): + """Manage conversation memory to stay within limits.""" + if len(self.conversation_history) > self.config.memory_limit: + # Keep system message and recent messages + system_msgs = [msg for msg in self.conversation_history if isinstance(msg, SystemMessage)] + recent_msgs = self.conversation_history[-(self.config.memory_limit - len(system_msgs)):] + self.conversation_history = system_msgs + recent_msgs + + def get_status(self) -> Dict[str, Any]: + """Get current status of Hal agent.""" + return { + "name": self.config.name, + "active": self.is_active, + "model": self.config.model, + "conversation_length": len(self.conversation_history), + "available_tools": len(self.tools), + "current_task": self.current_task + } + + def reset_conversation(self): + """Reset conversation history while keeping system message.""" + system_msgs = [msg for msg in self.conversation_history if isinstance(msg, SystemMessage)] + self.conversation_history = system_msgs + + def update_config(self, **kwargs): + """Update Hal's configuration.""" + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + + +# Convenience function to create Hal agent +def create_dhal( + name: str = "Dhal", + system_message: str = None, + model: str = "gpt-4", + model_client: ChatCompletionClient = None, + **kwargs +) -> Dhal: + """Create a Dhal agent with specified configuration.""" + + if system_message is None: + system_message = f"You are {name}, an advanced AI assistant integrated into DarkHal 2.0. You help users with coding, analysis, security testing, and general AI tasks." + + config = DhalConfig( + name=name, + system_message=system_message, + model=model, + **kwargs + ) + + # Initialize model client if not provided + if model_client is None: + try: + model_client = HalModelClient(model_name=model) + except Exception as e: + print(f"[Dhal] Warning: Could not initialize model client: {e}") + model_client = None + + return Dhal(config, model_client) diff --git a/agent_dhal_integration.py b/agent_dhal_integration.py new file mode 100644 index 0000000..78d95c3 --- /dev/null +++ b/agent_dhal_integration.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python3 +""" +Agent Dhal Integration for DarkHal 2.0 + +Connects the DarkAgent to the main application's UI and model loading system. +Provides thread-safe communication between the agent and the Tkinter interface. +""" + +import threading +import queue +import time +import json +from pathlib import Path +from typing import Optional, Dict, Any +from tkinter import messagebox +from agent_debug_tracer import get_tracer, trace +import __spy as spy # global announcer for current model + +try: + from agent_dhal.hal import Dhal as DarkAgent +except ImportError: + try: + from agent_dhal.hal import Hal as DarkAgent # fallback class name + except ImportError: + DarkAgent = None # type: ignore +AGENT_AVAILABLE = DarkAgent is not None + +class ExistingRuntimeClient: + """ + Thin adapter that wraps an already-loaded model object from the app and + exposes a chat-completion-like interface for the agent. + """ + def __init__(self, model): + self._model = model + + def is_available(self) -> bool: + return self._model is not None + + def create_chat_completion( + self, + messages, + temperature: float = 0.7, + max_tokens: int = 2048, + tools=None, + stream: bool = False, + on_token=None, + on_complete=None, + on_error=None, + top_p: float = 0.9, + ): + """ + Minimal bridge: + - Flattens chat messages into a single prompt. + - Uses model.stream(...) if available when stream=True; otherwise model.generate(...). + - Tries to construct a generation config if llm_runtime.GenerateConfig is available. + """ + try: + prompt = self._flatten_messages(messages) + config = self._build_config(max_tokens=max_tokens, temperature=temperature, top_p=top_p) + + if stream and hasattr(self._model, "stream"): + full = [] + for chunk in self._model.stream(prompt, config) if config is not None else self._model.stream(prompt): + if on_token: + try: + on_token("request-1", chunk) + except Exception: + pass + full.append(chunk) + final_text = "".join(full) + if on_complete: + on_complete("request-1", final_text, {"finish_reason": "stop"}) + return {"text": final_text, "finish_reason": "stop"} + + # Non-streaming path + if hasattr(self._model, "generate"): + text = self._model.generate(prompt, config) if config is not None else self._model.generate(prompt) + if on_complete: + on_complete("request-1", text, {"finish_reason": "stop"}) + return {"text": text, "finish_reason": "stop"} + + raise RuntimeError("Model does not support generate/stream") + + except Exception as e: + if on_error: + on_error("request-1", e) + else: + raise + + def _flatten_messages(self, messages) -> str: + # Simple readable flattening + lines = [] + for m in messages or []: + role = (m.get("role") if isinstance(m, dict) else getattr(m, "role", "user")).upper() + content = (m.get("content") if isinstance(m, dict) else getattr(m, "content", "")) or "" + lines.append(f"{role}: {content}") + return "\n".join(lines) if lines else "" + + def _build_config(self, max_tokens: int, temperature: float, top_p: float): + # Let the underlying model decide all defaults (no explicit config). + return None + + + +class DhalAgentIntegration: + """Integration layer between DarkAgent and DarkHal 2.0 UI.""" + + def __init__(self, agents_tab): + trace("INTEGRATION_INIT", "DhalAgentIntegration initializing") + self.agents_tab = agents_tab + self.agent: Optional[DarkAgent] = None + self.client: Optional[ExistingRuntimeClient] = None + self.conversation_id = "main_conversation" + self.tracer = get_tracer() + + # UI update queue for thread-safe communication + self.ui_queue = queue.Queue() + self.is_running = False + + # Start UI update checker + self._start_ui_updater() + + def _start_ui_updater(self): + """Start the UI update thread.""" + def update_ui(): + while True: + try: + action, data = self.ui_queue.get(timeout=0.1) + if action == "append_text": + self._append_text_safe(data) + elif action == "set_status": + self._set_status_safe(data) + elif action == "toggle_buttons": + self._toggle_buttons_safe(data) + self.ui_queue.task_done() + except queue.Empty: + continue + except Exception as e: + print(f"UI update error: {e}") + + ui_thread = threading.Thread(target=update_ui, daemon=True) + ui_thread.start() + + def _append_text_safe(self, text: str): + """Thread-safe text append to chat output.""" + try: + self.agents_tab.hal_output.insert("end", text) + self.agents_tab.hal_output.see("end") + except Exception as e: + print(f"Text append error: {e}") + + def _set_status_safe(self, status: str): + """Thread-safe status update.""" + try: + self.agents_tab.hal_status_var.set(status) + except Exception as e: + print(f"Status update error: {e}") + + def _toggle_buttons_safe(self, running: bool): + """Thread-safe button state toggle.""" + try: + if running: + self.agents_tab.hal_start_btn.config(state="disabled") + self.agents_tab.hal_stop_btn.config(state="normal") + self.agents_tab.hal_send_btn.config(state="normal") + else: + self.agents_tab.hal_start_btn.config(state="normal") + self.agents_tab.hal_stop_btn.config(state="disabled") + self.agents_tab.hal_send_btn.config(state="disabled") + except Exception as e: + print(f"Button toggle error: {e}") + + def start_agent(self): + """Start the Dark Agent.""" + if not AGENT_AVAILABLE: + messagebox.showerror("Agent Error", "AgentDhal framework not available. Please check installation.") + return + + try: + # Find the main app object + main_app = getattr(self.agents_tab, "main_app", None) or self._get_main_app() + # Resolve the currently loaded model from common attribute names + current_model = self._resolve_current_model(main_app) + if current_model is None: + messagebox.showwarning("No Model", "Please load a model first before starting the Dark Agent.") + return + + # Create client adapter for the existing model + self.client = ExistingRuntimeClient(current_model) + + # Get configuration from UI and spy + agent_name = "Dhal" # Fixed agent name + system_prompt = self.agents_tab.hal_system_var.get() or f"You are {agent_name}, an advanced AI assistant integrated into DarkHal 2.0." + model_name = (spy.get_model_name() or self.agents_tab.hal_model_var.get() or "local-llm") + + # Create agent object; prefer factory if available + created = False + try: + from agent_dhal.hal import create_dhal # factory supports (name, system_message, model, model_client) + self.agent = create_dhal(name=agent_name, system_message=system_prompt, model=model_name, model_client=self.client) + created = True + except Exception: + try: + from agent_dhal.hal import DhalConfig, Dhal as DarkAgentCtor + cfg = DhalConfig(name=agent_name, system_message=system_prompt, model=model_name) + self.agent = DarkAgentCtor(cfg, self.client) + created = True + except Exception: + # Last fallback: legacy constructor (may not work on all versions) + try: + self.agent = DarkAgent(self.client) + created = True + except Exception: + created = False + + # Apply runtime configuration: let the model choose its own token/context defaults. + try: + tools_cfg = {name: var.get() for name, var in getattr(self.agents_tab, "hal_tools", {}).items()} + except Exception: + tools_cfg = {} + + if created and hasattr(self.agent, "update_config"): + # Only set what is explicitly from UI that doesn't override model token defaults + self.agent.update_config(system_message=system_prompt, tools=tools_cfg) + + # Start the agent runtime if start method exists + if created and hasattr(self.agent, "start_dhal"): + self.agent.start_dhal() + + self.is_running = True + + # Update UI + self.ui_queue.put(("set_status", f"{agent_name} Status: Running")) + self.ui_queue.put(("toggle_buttons", True)) + self.ui_queue.put(("append_text", f"\\n{agent_name} agent started successfully!\\n")) + self.ui_queue.put(("append_text", f"System: {system_prompt}\\n\\n")) + + except Exception as e: + messagebox.showerror("Start Error", f"Failed to start Dark Agent: {str(e)}") + self.ui_queue.put(("set_status", "Dark Agent Status: Error")) + + def stop_agent(self): + """Stop the Dark Agent.""" + try: + if self.agent: + self.agent.shutdown() + self.agent = None + + if self.client: + self.client = None + + self.is_running = False + + # Update UI + agent_name = "Dhal" + self.ui_queue.put(("set_status", f"{agent_name} Status: Stopped")) + self.ui_queue.put(("toggle_buttons", False)) + self.ui_queue.put(("append_text", f"\\n{agent_name} agent stopped.\\n\\n")) + + except Exception as e: + messagebox.showerror("Stop Error", f"Failed to stop Dark Agent: {str(e)}") + + def send_message(self): + """Send message to Dark Agent.""" + if not self.is_running or not self.agent: + messagebox.showwarning("Agent Not Running", "Please start the Dark Agent first.") + return + + try: + message = self.agents_tab.hal_input_var.get().strip() + if not message: + return + + # Clear input + self.agents_tab.hal_input_var.set("") + + # Add user message to chat + self.ui_queue.put(("append_text", f"User: {message}\\n")) + + # Define callback functions + def on_token(request_id: str, delta: str): + self.ui_queue.put(("append_text", delta)) + + def on_complete(request_id: str, full_text: str, metadata: Dict[str, Any]): + self.ui_queue.put(("append_text", "\\n\\n")) + + def on_error(request_id: str, error: Exception): + self.ui_queue.put(("append_text", f"\\nError: {str(error)}\\n\\n")) + + # Add assistant prefix + agent_name = "Dhal" + self.ui_queue.put(("append_text", f"{agent_name}: ")) + + # Send message to agent (non-blocking) + self.agent.send_dhal_message( + self.conversation_id, + message, + stream=True, + on_token=on_token, + on_complete=on_complete, + on_error=on_error + ) + + except Exception as e: + messagebox.showerror("Message Error", f"Failed to send message: {str(e)}") + self.ui_queue.put(("append_text", f"\\nError: {str(e)}\\n\\n")) + + def reset_conversation(self): + """Reset the agent conversation.""" + try: + if self.agent: + self.agent.reset_conversation(self.conversation_id) + + # Clear chat output + self.agents_tab.hal_output.delete(1.0, "end") + + agent_name = "Dhal" + self.ui_queue.put(("append_text", f"{agent_name} conversation reset.\\n\\n")) + + except Exception as e: + messagebox.showerror("Reset Error", f"Failed to reset conversation: {str(e)}") + + def save_config(self): + """Save agent configuration.""" + try: + config = { + "agent_name": "Dhal", + "model": self.agents_tab.hal_model_var.get(), + "system_message": self.agents_tab.hal_system_var.get(), + "temperature": self.agents_tab.hal_temp_var.get(), + "tools": {name: var.get() for name, var in self.agents_tab.hal_tools.items()} + } + + config_file = Path("agent_dhal_config.json") + with open(config_file, 'w') as f: + json.dump(config, f, indent=2) + + self.ui_queue.put(("append_text", f"Configuration saved to {config_file}\\n")) + + except Exception as e: + messagebox.showerror("Save Error", f"Failed to save configuration: {str(e)}") + + def load_config(self): + """Load agent configuration.""" + try: + config_file = Path("agent_dhal_config.json") + if not config_file.exists(): + messagebox.showinfo("No Config", "No saved configuration found.") + return + + with open(config_file, 'r') as f: + config = json.load(f) + + # Apply configuration to UI + if hasattr(self.agents_tab, "hal_name_var"): + self.agents_tab.hal_name_var.set(config.get("agent_name", "Dhal")) + self.agents_tab.hal_model_var.set(config.get("model", "local-llm")) + self.agents_tab.hal_system_var.set(config.get("system_message", "You are Dhal, an advanced AI assistant.")) + self.agents_tab.hal_temp_var.set(config.get("temperature", "0.7")) + + # Apply tool settings + tools_config = config.get("tools", {}) + for name, var in self.agents_tab.hal_tools.items(): + var.set(tools_config.get(name, True)) + + self.ui_queue.put(("append_text", f"Configuration loaded from {config_file}\\n")) + + except Exception as e: + messagebox.showerror("Load Error", f"Failed to load configuration: {str(e)}") + + def _resolve_current_model(self, main_app) -> Optional[Any]: + """ + Try several common attribute names to find the loaded model on the main app/controller. + Falls back to __spy if no attribute is found. + Returns the model object or None. + """ + # First, try the spy announcer if it has a model cached + try: + m = spy.get_model() + if m is not None: + return m + except Exception: + pass + + if main_app is None: + return None + candidate_attrs = [ + "current_model", + "model", + "llm_model", + "loaded_model", + "runtime_model", + ] + for attr in candidate_attrs: + try: + value = getattr(main_app, attr, None) + if value is not None: + return value + except Exception: + continue + return None + + def _get_main_app(self): + """Get reference to main application.""" + try: + # Navigate up the widget hierarchy to find the main app + parent = self.agents_tab.parent + while parent and not hasattr(parent, 'current_model'): + parent = getattr(parent, 'master', None) or getattr(parent, 'parent', None) + if hasattr(parent, 'winfo_toplevel'): + toplevel = parent.winfo_toplevel() + if hasattr(toplevel, 'current_model'): + return toplevel + return parent + except Exception: + return None + + +# Legacy class name for compatibility +HALAgentIntegration = DhalAgentIntegration \ No newline at end of file diff --git a/agent_mode.py b/agent_mode.py new file mode 100644 index 0000000..501d115 --- /dev/null +++ b/agent_mode.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +Agent Mode for DarkHal 2.0 +Integrates Hal's unrestricted capabilities into the chat interface +""" + +import os +import sys +import json +import asyncio +import threading +from typing import Optional, Callable, Dict, Any, List +from dataclasses import dataclass +import tkinter as tk +from tkinter import messagebox + +# Add agent_dhal to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'agent_dhal')) + +# Import Hal components +try: + from agent_dhal.hal import create_dhal, DhalConfig, HalModelClient +except ImportError: + # Fallback for testing + create_dhal = None + DhalConfig = None + HalModelClient = None + print("[Agent Mode] Warning: Could not import Hal components") + + +@dataclass +class AgentModeConfig: + """Configuration for agent mode""" + enabled: bool = False + allow_file_operations: bool = True + allow_shell_commands: bool = True + allow_system_control: bool = True # mouse, keyboard + allow_elevated_commands: bool = False # sudo, admin powershell + require_confirmation: bool = True # Ask before dangerous operations + model_name: str = "gpt-4" + + +class AgentModeHandler: + """Handles agent mode functionality for the chat interface""" + + def __init__(self, output_callback: Callable[[str], None] = None): + """ + Initialize agent mode handler + + Args: + output_callback: Function to call with output text + """ + self.config = AgentModeConfig() + self.output_callback = output_callback or print + self.agent = None + self.is_initialized = False + self._init_lock = threading.Lock() + + def initialize_agent(self, model_path: str = None) -> bool: + """ + Initialize the Hal agent + + Args: + model_path: Path to model file or model name + + Returns: + True if initialized successfully + """ + with self._init_lock: + if self.is_initialized: + return True + + try: + if not create_dhal: + self.output_callback("[Agent Mode] Hal components not available\n") + return False + + # Configure agent based on settings + config = DhalConfig( + name="Hal", + system_message=self._get_system_prompt(), + model=model_path or "gpt-4", + temperature=0.7, + max_tokens=2000 + ) + + # Create model client + model_client = HalModelClient(model_name=model_path or "gpt-4") + + # Create agent + self.agent = create_dhal( + name="Hal", + system_message=self._get_system_prompt(), + model=model_path or "gpt-4", + model_client=model_client + ) + + self.is_initialized = True + self.output_callback("[Agent Mode] Hal agent initialized successfully\n") + return True + + except Exception as e: + self.output_callback(f"[Agent Mode] Failed to initialize: {e}\n") + return False + + def _get_system_prompt(self) -> str: + """Generate system prompt based on current permissions""" + prompt = """You are Hal, an advanced AI assistant with direct system access. +You can execute commands and control the system based on user requests. + +Available capabilities:""" + + if self.config.allow_file_operations: + prompt += "\n- Read, write, and list files" + if self.config.allow_shell_commands: + prompt += "\n- Execute shell commands (bash, PowerShell)" + prompt += "\n- Run Python code" + if self.config.allow_system_control: + prompt += "\n- Control mouse and keyboard" + prompt += "\n- Send keystrokes and type text" + if self.config.allow_elevated_commands: + prompt += "\n- Execute elevated/admin commands" + + prompt += "\n\nAlways explain what you're doing before executing commands." + + if self.config.require_confirmation: + prompt += "\nWait for user confirmation before destructive operations." + + return prompt + + async def process_message_async(self, message: str) -> str: + """ + Process a message through the agent asynchronously + + Args: + message: User message to process + + Returns: + Agent response + """ + if not self.is_initialized: + return "[Agent Mode] Not initialized. Please enable agent mode first." + + try: + # Get response from agent + from agent_dhal.hal import MessageContext + + # Create a mock context for standalone usage + class MockContext: + def __init__(self): + self.agent_id = "user" + + ctx = MockContext() + + # Process message through agent + response = await self.agent.handle_user_message(message, ctx) + + return response + + except Exception as e: + return f"[Agent Mode] Error processing message: {e}" + + def process_message(self, message: str) -> str: + """ + Process a message through the agent (synchronous wrapper) + + Args: + message: User message to process + + Returns: + Agent response + """ + # Run async function in new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self.process_message_async(message)) + finally: + loop.close() + + def enable(self, model_path: str = None) -> bool: + """ + Enable agent mode + + Args: + model_path: Path to model or model name + + Returns: + True if enabled successfully + """ + if not self.is_initialized: + if not self.initialize_agent(model_path): + return False + + self.config.enabled = True + self.output_callback("[Agent Mode] ENABLED - Hal has full system access\n") + self._show_capabilities() + return True + + def disable(self): + """Disable agent mode""" + self.config.enabled = False + self.output_callback("[Agent Mode] DISABLED - Normal chat mode\n") + + def _show_capabilities(self): + """Show current capabilities to user""" + caps = [] + if self.config.allow_file_operations: + caps.append("• File operations (read/write/list)") + if self.config.allow_shell_commands: + caps.append("• Shell commands (bash/PowerShell/Python)") + if self.config.allow_system_control: + caps.append("• Mouse & keyboard control") + if self.config.allow_elevated_commands: + caps.append("• Elevated/admin commands") + + if caps: + self.output_callback("Enabled capabilities:\n" + "\n".join(caps) + "\n") + + def update_config(self, **kwargs): + """Update configuration settings""" + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + + # Reinitialize agent if needed + if self.is_initialized: + self.agent.update_config(system_message=self._get_system_prompt()) + + +class AgentModeUI: + """UI components for agent mode""" + + def __init__(self, parent_frame: tk.Frame, handler: AgentModeHandler): + """ + Create agent mode UI controls + + Args: + parent_frame: Parent tkinter frame + handler: AgentModeHandler instance + """ + self.handler = handler + self.parent = parent_frame + + # Create control frame + self.control_frame = tk.LabelFrame(parent_frame, text="🤖 Agent Mode", relief=tk.RIDGE, borderwidth=2) + self.control_frame.pack(fill=tk.X, padx=5, pady=5) + + # Main toggle + self.enabled_var = tk.BooleanVar(value=False) + self.enable_check = tk.Checkbutton( + self.control_frame, + text="Enable Agent Mode (FULL SYSTEM ACCESS)", + variable=self.enabled_var, + command=self._toggle_agent_mode, + font=('Arial', 10, 'bold'), + fg='red' + ) + self.enable_check.grid(row=0, column=0, columnspan=3, sticky=tk.W, padx=5, pady=5) + + # Permission toggles + self.file_ops_var = tk.BooleanVar(value=True) + tk.Checkbutton( + self.control_frame, + text="File Operations", + variable=self.file_ops_var, + command=self._update_permissions + ).grid(row=1, column=0, sticky=tk.W, padx=20, pady=2) + + self.shell_var = tk.BooleanVar(value=True) + tk.Checkbutton( + self.control_frame, + text="Shell Commands", + variable=self.shell_var, + command=self._update_permissions + ).grid(row=1, column=1, sticky=tk.W, padx=20, pady=2) + + self.system_var = tk.BooleanVar(value=True) + tk.Checkbutton( + self.control_frame, + text="Mouse/Keyboard", + variable=self.system_var, + command=self._update_permissions + ).grid(row=1, column=2, sticky=tk.W, padx=20, pady=2) + + self.elevated_var = tk.BooleanVar(value=False) + tk.Checkbutton( + self.control_frame, + text="Elevated Commands (DANGEROUS)", + variable=self.elevated_var, + command=self._update_permissions, + fg='orange' + ).grid(row=2, column=0, columnspan=2, sticky=tk.W, padx=20, pady=2) + + self.confirm_var = tk.BooleanVar(value=True) + tk.Checkbutton( + self.control_frame, + text="Require Confirmation", + variable=self.confirm_var, + command=self._update_permissions + ).grid(row=2, column=2, sticky=tk.W, padx=20, pady=2) + + # Status label + self.status_label = tk.Label( + self.control_frame, + text="Status: DISABLED", + font=('Arial', 9), + fg='gray' + ) + self.status_label.grid(row=3, column=0, columnspan=3, pady=5) + + def _toggle_agent_mode(self): + """Toggle agent mode on/off""" + if self.enabled_var.get(): + # Show warning + result = messagebox.askyesno( + "⚠️ Enable Agent Mode", + "WARNING: Agent mode gives the AI unrestricted access to:\n\n" + "• Your file system\n" + "• Shell commands\n" + "• Mouse and keyboard control\n" + "• System settings\n\n" + "Only enable if you understand the risks.\n\n" + "Continue?", + icon='warning' + ) + + if result: + # Get model path from parent if available + model_path = None + if hasattr(self.parent.master, 'model_var'): + model_path = self.parent.master.model_var.get() + + if self.handler.enable(model_path): + self.status_label.config(text="Status: ACTIVE", fg='red') + self._update_permissions() + else: + self.enabled_var.set(False) + messagebox.showerror("Error", "Failed to initialize agent mode") + else: + self.enabled_var.set(False) + else: + self.handler.disable() + self.status_label.config(text="Status: DISABLED", fg='gray') + + def _update_permissions(self): + """Update handler permissions based on UI settings""" + self.handler.update_config( + allow_file_operations=self.file_ops_var.get(), + allow_shell_commands=self.shell_var.get(), + allow_system_control=self.system_var.get(), + allow_elevated_commands=self.elevated_var.get(), + require_confirmation=self.confirm_var.get() + ) + + def is_enabled(self) -> bool: + """Check if agent mode is enabled""" + return self.enabled_var.get() and self.handler.config.enabled + + +# Example usage +if __name__ == "__main__": + # Test agent mode + handler = AgentModeHandler() + + print("Testing agent mode...") + if handler.enable(): + response = handler.process_message("List files in the current directory") + print(f"Response: {response}") + + response = handler.process_message("What's 2+2? Calculate with Python") + print(f"Response: {response}") + + handler.disable() \ No newline at end of file diff --git a/assets/Halico.ico b/assets/Halico.ico new file mode 100644 index 0000000..a9dd259 Binary files /dev/null and b/assets/Halico.ico differ diff --git a/assets/Halico.png b/assets/Halico.png new file mode 100644 index 0000000..fd7c0a7 Binary files /dev/null and b/assets/Halico.png differ diff --git a/assets/logo.ico b/assets/logo.ico new file mode 100644 index 0000000..3567ef0 Binary files /dev/null and b/assets/logo.ico differ diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..9eeed94 Binary files /dev/null and b/assets/logo.png differ diff --git a/build_all.bat b/build_all.bat new file mode 100644 index 0000000..ba41409 --- /dev/null +++ b/build_all.bat @@ -0,0 +1,6 @@ +@echo off +echo Building DarkHal 2.0 for all platforms... +echo (Windows x64, Linux x64, Linux ARM64) +echo. +python build_all_platforms.py all +pause \ No newline at end of file diff --git a/build_all.sh b/build_all.sh new file mode 100644 index 0000000..5054974 --- /dev/null +++ b/build_all.sh @@ -0,0 +1,5 @@ +#!/bin/bash +echo "Building DarkHal 2.0 for all platforms..." +echo "(Windows x64, Linux x64, Linux ARM64)" +echo "" +python3 build_all_platforms.py all \ No newline at end of file diff --git a/build_all_platforms.py b/build_all_platforms.py new file mode 100644 index 0000000..4236cc3 --- /dev/null +++ b/build_all_platforms.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +DarkHal 2.0 Build Script +Builds for: Windows x64, Linux x64, Linux ARM64 +""" + +import os +import sys +import subprocess +import platform +import argparse +from pathlib import Path + +# Build configurations for the three target platforms +BUILD_CONFIGS = { + 'windows-x64': { + 'pyinstaller_args': [ + '--onefile', + '--windowed', + '--name', 'DarkHal-2.0-windows-x64', + '--icon', 'assets/logo.ico', + '--add-data', 'assets/*;assets', + '--add-data', 'llm_runtime;llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers' + ], + 'output_name': 'DarkHal-2.0-windows-x64.exe' + }, + 'linux-x64': { + 'pyinstaller_args': [ + '--onefile', + '--name', 'DarkHal-2.0-linux-x64', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers' + ], + 'output_name': 'DarkHal-2.0-linux-x64' + }, + 'linux-arm64': { + 'pyinstaller_args': [ + '--onefile', + '--name', 'DarkHal-2.0-linux-arm64', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'arm64' + ], + 'output_name': 'DarkHal-2.0-linux-arm64' + } +} + +def install_pyinstaller(): + """Install PyInstaller if not present""" + try: + import PyInstaller + print("✓ PyInstaller already installed") + return True + except ImportError: + print("Installing PyInstaller...") + try: + subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pyinstaller']) + print("✓ PyInstaller installed successfully") + return True + except subprocess.CalledProcessError as e: + print(f"✗ Failed to install PyInstaller: {e}") + return False + +def clean_build_files(): + """Clean previous build files""" + import shutil + import glob + + dirs_to_remove = ['build', 'dist'] + for dir_name in dirs_to_remove: + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + print(f"Cleaned {dir_name}/ directory") + + for spec_file in glob.glob('*.spec'): + os.remove(spec_file) + print(f"Cleaned {spec_file}") + +def build_executable(target_platform): + """Build executable for specified platform""" + if target_platform not in BUILD_CONFIGS: + print(f"✗ Unknown target platform: {target_platform}") + print(f"Available platforms: {', '.join(BUILD_CONFIGS.keys())}") + return False + + config = BUILD_CONFIGS[target_platform] + + print(f"\n{'='*50}") + print(f"Building DarkHal 2.0 for {target_platform}") + print('='*50) + + # Check required files + if not os.path.exists('main.py'): + print("✗ main.py not found. Please run from the correct directory.") + return False + + if not os.path.exists('llm_runtime'): + print("✗ llm_runtime directory not found") + return False + + # Create dist directory for this platform + dist_dir = f'dist/{target_platform}' + os.makedirs(dist_dir, exist_ok=True) + + # Build PyInstaller command + cmd = ['pyinstaller'] + config['pyinstaller_args'] + [ + '--distpath', dist_dir, + '--workpath', f'build/{target_platform}', + '--specpath', f'build/{target_platform}', + 'main.py' + ] + + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, check=True) + print("✓ PyInstaller completed") + + # Find the created executable + expected_path = os.path.join(dist_dir, config['output_name']) + name_without_ext = config['output_name'].replace('.exe', '') + possible_paths = [ + expected_path, + os.path.join(dist_dir, name_without_ext), + os.path.join(dist_dir, name_without_ext, name_without_ext), + os.path.join(dist_dir, name_without_ext, name_without_ext + '.exe') + ] + + executable_path = None + for path in possible_paths: + if os.path.exists(path) and os.path.isfile(path): + executable_path = path + break + + if executable_path: + # Move to correct location if needed + final_path = expected_path + if executable_path != final_path: + if os.path.exists(final_path): + os.remove(final_path) + os.rename(executable_path, final_path) + executable_path = final_path + + # Make executable on Linux + if target_platform.startswith('linux'): + os.chmod(executable_path, 0o755) + + file_size = os.path.getsize(executable_path) / (1024 * 1024) + print(f"✓ SUCCESS! Executable created: {executable_path}") + print(f" File size: {file_size:.1f} MB") + return True + else: + print("✗ Executable not found after build") + print("Files in dist directory:") + for root, dirs, files in os.walk(dist_dir): + for file in files: + print(f" {os.path.join(root, file)}") + return False + + except subprocess.CalledProcessError as e: + print(f"✗ Build failed with exit code {e.returncode}") + return False + +def main(): + parser = argparse.ArgumentParser(description='DarkHal 2.0 Cross-Platform Build Tool') + parser.add_argument('target', nargs='?', + choices=['windows-x64', 'linux-x64', 'linux-arm64', 'all'], + default='all', + help='Target platform to build for') + parser.add_argument('--no-clean', action='store_true', help='Skip cleaning build files') + + args = parser.parse_args() + + print("DarkHal 2.0 Cross-Platform Builder") + print("Supporting: Windows x64, Linux x64, Linux ARM64") + + # Install PyInstaller + if not install_pyinstaller(): + sys.exit(1) + + # Clean build files unless requested not to + if not args.no_clean: + clean_build_files() + + # Build targets + if args.target == 'all': + targets = ['windows-x64', 'linux-x64', 'linux-arm64'] + success_count = 0 + + for target in targets: + if build_executable(target): + success_count += 1 + + print(f"\n{'='*60}") + print(f"Build Summary: {success_count}/{len(targets)} platforms successful") + print('='*60) + + if success_count == len(targets): + print("🎉 All builds completed successfully!") + print("\nExecutables created:") + for target in targets: + config = BUILD_CONFIGS[target] + path = f"dist/{target}/{config['output_name']}" + if os.path.exists(path): + size = os.path.getsize(path) / (1024 * 1024) + print(f" {path} ({size:.1f} MB)") + else: + print(f"❌ {len(targets) - success_count} builds failed!") + sys.exit(1) + else: + if build_executable(args.target): + print("\n🎉 Build completed successfully!") + else: + print("\n❌ Build failed!") + sys.exit(1) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/build_cross_platform.py b/build_cross_platform.py new file mode 100644 index 0000000..4e65103 --- /dev/null +++ b/build_cross_platform.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +DarkHal 2.0 Cross-Platform Build Script +Supports: Windows (x64, ARM64), Linux (x64, ARM64), macOS (x64, ARM64) +""" + +import os +import sys +import subprocess +import platform +import argparse +from pathlib import Path + +# Build configurations for different platforms/architectures +BUILD_CONFIGS = { + 'windows-x64': { + 'platform': 'win32', + 'arch': 'x64', + 'pyinstaller_args': [ + '--onefile', + '--windowed', + '--icon', 'assets/logo.ico', + '--add-data', 'assets/*;assets', + '--add-data', 'llm_runtime;llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'x86_64' + ], + 'output_name': 'DarkHal-2.0-windows-x64.exe' + }, + 'windows-arm64': { + 'platform': 'win32', + 'arch': 'arm64', + 'pyinstaller_args': [ + '--onefile', + '--windowed', + '--icon', 'assets/logo.ico', + '--add-data', 'assets/*;assets', + '--add-data', 'llm_runtime;llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'arm64' + ], + 'output_name': 'DarkHal-2.0-windows-arm64.exe' + }, + 'linux-x64': { + 'platform': 'linux', + 'arch': 'x64', + 'pyinstaller_args': [ + '--onefile', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'x86_64' + ], + 'output_name': 'DarkHal-2.0-linux-x64' + }, + 'linux-arm64': { + 'platform': 'linux', + 'arch': 'arm64', + 'pyinstaller_args': [ + '--onefile', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'arm64' + ], + 'output_name': 'DarkHal-2.0-linux-arm64' + }, + 'macos-x64': { + 'platform': 'darwin', + 'arch': 'x64', + 'pyinstaller_args': [ + '--onefile', + '--windowed', + '--icon', 'assets/logo.icns', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'x86_64' + ], + 'output_name': 'DarkHal-2.0-macos-x64' + }, + 'macos-arm64': { + 'platform': 'darwin', + 'arch': 'arm64', + 'pyinstaller_args': [ + '--onefile', + '--windowed', + '--icon', 'assets/logo.icns', + '--add-data', 'assets/*:assets', + '--add-data', 'llm_runtime:llm_runtime', + '--hidden-import', 'tkinter', + '--hidden-import', 'tkinter.ttk', + '--hidden-import', 'torch', + '--hidden-import', 'transformers', + '--target-architecture', 'arm64' + ], + 'output_name': 'DarkHal-2.0-macos-arm64' + } +} + +def get_current_platform(): + """Detect current platform and architecture""" + system = platform.system().lower() + machine = platform.machine().lower() + + if system == 'windows': + platform_name = 'windows' + elif system == 'linux': + platform_name = 'linux' + elif system == 'darwin': + platform_name = 'macos' + else: + platform_name = system + + if machine in ['amd64', 'x86_64']: + arch = 'x64' + elif machine in ['arm64', 'aarch64']: + arch = 'arm64' + else: + arch = machine + + return f"{platform_name}-{arch}" + +def install_pyinstaller(): + """Install PyInstaller if not present""" + try: + import PyInstaller + print("✓ PyInstaller already installed") + return True + except ImportError: + print("Installing PyInstaller...") + try: + subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pyinstaller']) + print("✓ PyInstaller installed successfully") + return True + except subprocess.CalledProcessError as e: + print(f"✗ Failed to install PyInstaller: {e}") + return False + +def clean_build_files(): + """Clean previous build files""" + dirs_to_remove = ['build', 'dist'] + files_to_remove = ['*.spec'] + + for dir_name in dirs_to_remove: + if os.path.exists(dir_name): + import shutil + shutil.rmtree(dir_name) + print(f"Cleaned {dir_name}/ directory") + + import glob + for pattern in files_to_remove: + for file in glob.glob(pattern): + os.remove(file) + print(f"Cleaned {file}") + +def build_executable(target_platform, clean=True): + """Build executable for specified platform""" + if target_platform not in BUILD_CONFIGS: + print(f"✗ Unknown target platform: {target_platform}") + print(f"Available platforms: {', '.join(BUILD_CONFIGS.keys())}") + return False + + config = BUILD_CONFIGS[target_platform] + + print(f"Building DarkHal 2.0 for {target_platform}...") + print(f"Platform: {config['platform']}, Architecture: {config['arch']}") + + # Check if we have required files + if not os.path.exists('main.py'): + print("✗ main.py not found. Please run from the correct directory.") + return False + + if not os.path.exists('llm_runtime'): + print("✗ llm_runtime directory not found") + return False + + # Clean previous builds + if clean: + clean_build_files() + + # Build PyInstaller command + cmd = [ + 'pyinstaller', + '--name', config['output_name'].replace('.exe', '').replace('.app', ''), + '--distpath', f'dist/{target_platform}', + '--workpath', f'build/{target_platform}', + '--specpath', f'build/{target_platform}' + ] + config['pyinstaller_args'] + ['main.py'] + + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + print("✓ Build completed successfully") + + # Check output file + expected_output = f"dist/{target_platform}/{config['output_name']}" + # PyInstaller might create without extension on some platforms + possible_outputs = [ + expected_output, + expected_output.replace('.exe', ''), + f"dist/{target_platform}/{config['output_name'].replace('.exe', '').replace('.app', '')}" + ] + + output_file = None + for possible in possible_outputs: + if os.path.exists(possible): + output_file = possible + break + + if output_file: + # Rename to correct name if needed + if output_file != expected_output and expected_output.endswith('.exe'): + os.rename(output_file, expected_output) + output_file = expected_output + + file_size = os.path.getsize(output_file) / (1024 * 1024) # MB + print(f"✓ Executable created: {output_file}") + print(f" File size: {file_size:.1f} MB") + return True + else: + print("✗ Executable not found after build") + return False + + except subprocess.CalledProcessError as e: + print(f"✗ Build failed: {e}") + if e.stdout: + print("STDOUT:", e.stdout) + if e.stderr: + print("STDERR:", e.stderr) + return False + +def build_all_platforms(): + """Build for all supported platforms""" + current = get_current_platform() + print(f"Current platform detected: {current}") + + success_count = 0 + total_count = len(BUILD_CONFIGS) + + for platform in BUILD_CONFIGS.keys(): + print(f"\n{'='*60}") + print(f"Building for {platform}") + print('='*60) + + if build_executable(platform, clean=False): + success_count += 1 + else: + print(f"Failed to build for {platform}") + + print(f"\n{'='*60}") + print(f"Build Summary: {success_count}/{total_count} platforms successful") + print('='*60) + + if success_count == total_count: + print("✓ All builds completed successfully!") + return True + else: + print(f"✗ {total_count - success_count} builds failed") + return False + +def main(): + parser = argparse.ArgumentParser(description='DarkHal 2.0 Cross-Platform Build Tool') + parser.add_argument('target', nargs='?', choices=list(BUILD_CONFIGS.keys()) + ['all', 'current'], + default='current', help='Target platform to build for') + parser.add_argument('--no-clean', action='store_true', help='Skip cleaning build files') + parser.add_argument('--list', action='store_true', help='List available platforms') + + args = parser.parse_args() + + if args.list: + print("Available build targets:") + for platform, config in BUILD_CONFIGS.items(): + print(f" {platform:15} - {config['platform']} {config['arch']}") + return + + # Install PyInstaller + if not install_pyinstaller(): + return + + if args.target == 'all': + success = build_all_platforms() + elif args.target == 'current': + current_platform = get_current_platform() + if current_platform not in BUILD_CONFIGS: + print(f"✗ Current platform {current_platform} not supported") + print("Available platforms:", ', '.join(BUILD_CONFIGS.keys())) + return + success = build_executable(current_platform, not args.no_clean) + else: + success = build_executable(args.target, not args.no_clean) + + if success: + print("\n🎉 Build completed successfully!") + else: + print("\n❌ Build failed!") + sys.exit(1) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/build_linux_arm64.sh b/build_linux_arm64.sh new file mode 100644 index 0000000..34e889b --- /dev/null +++ b/build_linux_arm64.sh @@ -0,0 +1,3 @@ +#!/bin/bash +echo "Building DarkHal 2.0 for Linux ARM64..." +python3 build_all_platforms.py linux-arm64 \ No newline at end of file diff --git a/build_linux_x64.sh b/build_linux_x64.sh new file mode 100644 index 0000000..2587de9 --- /dev/null +++ b/build_linux_x64.sh @@ -0,0 +1,3 @@ +#!/bin/bash +echo "Building DarkHal 2.0 for Linux x64..." +python3 build_all_platforms.py linux-x64 \ No newline at end of file diff --git a/chat_templates.py b/chat_templates.py new file mode 100644 index 0000000..2c5abd4 --- /dev/null +++ b/chat_templates.py @@ -0,0 +1,393 @@ +""" +Chat Template Management System + +This module provides chat template management for different model formats, +allowing users to apply proper conversation formatting for optimal model performance. +""" + +import json +import os +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, asdict +import tkinter as tk +from tkinter import ttk, messagebox, simpledialog + + +@dataclass +class ChatTemplate: + """Represents a chat template configuration""" + name: str + description: str + system_prefix: str = "" + system_suffix: str = "" + user_prefix: str = "" + user_suffix: str = "" + assistant_prefix: str = "" + assistant_suffix: str = "" + turn_separator: str = "" + eos_token: str = "" + bos_token: str = "" + stop_tokens: List[str] = None + add_generation_prompt: bool = True + + def __post_init__(self): + if self.stop_tokens is None: + self.stop_tokens = [] + + +class ChatTemplateManager: + """Manages chat templates with JSON persistence""" + + def __init__(self, templates_file: str = "chat_templates.json"): + self.templates_file = templates_file + self.templates: Dict[str, ChatTemplate] = {} + self._load_templates() + self._ensure_default_templates() + + def _load_templates(self): + """Load templates from JSON file""" + if os.path.exists(self.templates_file): + try: + with open(self.templates_file, 'r', encoding='utf-8') as f: + data = json.load(f) + for name, template_data in data.items(): + self.templates[name] = ChatTemplate(**template_data) + except Exception as e: + print(f"Error loading chat templates: {e}") + + def _save_templates(self): + """Save templates to JSON file""" + try: + data = {name: asdict(template) for name, template in self.templates.items()} + with open(self.templates_file, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + except Exception as e: + print(f"Error saving chat templates: {e}") + + def _ensure_default_templates(self): + """Ensure default templates exist""" + if "Llama-3.1-Instruct" not in self.templates: + self.templates["Llama-3.1-Instruct"] = ChatTemplate( + name="Llama-3.1-Instruct", + description="Official Llama 3.1 Instruct chat template with proper headers and EOT tokens", + bos_token="<|begin_of_text|>", + system_prefix="<|start_header_id|>system<|end_header_id|>\n\n", + system_suffix="<|eot_id|>", + user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", + user_suffix="<|eot_id|>", + assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", + assistant_suffix="", # Model generates until <|eot_id|> + eos_token="<|eot_id|>", + stop_tokens=["<|eot_id|>"], + add_generation_prompt=True + ) + self._save_templates() + + def get_template_names(self) -> List[str]: + """Get list of all template names""" + return list(self.templates.keys()) + + def get_template(self, name: str) -> Optional[ChatTemplate]: + """Get a template by name""" + return self.templates.get(name) + + def add_template(self, template: ChatTemplate) -> bool: + """Add a new template""" + if template.name in self.templates: + return False # Template already exists + self.templates[template.name] = template + self._save_templates() + return True + + def update_template(self, template: ChatTemplate) -> None: + """Update an existing template""" + self.templates[template.name] = template + self._save_templates() + + def delete_template(self, name: str) -> bool: + """Delete a template by name""" + if name in self.templates: + del self.templates[name] + self._save_templates() + return True + return False + + def format_conversation(self, template_name: str, messages: List[Dict[str, str]], + add_generation_prompt: bool = True) -> str: + """Format a conversation using the specified template""" + template = self.get_template(template_name) + if not template: + # Fallback to simple User:/Assistant: format + return self._format_simple(messages, add_generation_prompt) + + result = [] + + # Add BOS token if specified + if template.bos_token: + result.append(template.bos_token) + + # System messages are handled by the calling code, so we don't need to add default ones here + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + + if role == "system": + if template.system_prefix or template.system_suffix: + result.append(f"{template.system_prefix}{content}{template.system_suffix}") + else: + result.append(content) + elif role == "user": + result.append(f"{template.user_prefix}{content}{template.user_suffix}") + elif role == "assistant": + result.append(f"{template.assistant_prefix}{content}{template.assistant_suffix}") + + # Add turn separator if specified (but not after the last message if we're adding generation prompt) + if template.turn_separator and not (add_generation_prompt and message == messages[-1]): + result.append(template.turn_separator) + + # Add generation prompt for assistant response + if add_generation_prompt and template.add_generation_prompt: + result.append(template.assistant_prefix) + + return "".join(result) + + def _format_simple(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str: + """Simple fallback formatting with User:/Assistant: labels""" + result = [] + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + + if role == "system": + result.append(f"System: {content}") + elif role == "user": + result.append(f"User: {content}") + elif role == "assistant": + result.append(f"Assistant: {content}") + + if add_generation_prompt: + result.append("Assistant:") + + return "\n".join(result) + + def get_stop_tokens(self, template_name: str) -> List[str]: + """Get stop tokens for a template""" + template = self.get_template(template_name) + if template and template.stop_tokens: + return template.stop_tokens + return [] + + +class ChatTemplateDialog: + """Dialog for creating/editing chat templates""" + + def __init__(self, parent: tk.Tk, template: ChatTemplate = None): + self.parent = parent + self.template = template + self.result = None + + # Create dialog + self.dialog = tk.Toplevel(parent) + self.dialog.title("Chat Template Editor") + self.dialog.geometry("600x700") + self.dialog.resizable(True, True) + + # Make dialog modal + self.dialog.transient(parent) + self.dialog.grab_set() + + self._build_ui() + + if template: + self._load_template_data() + + self._center_window() + + def _build_ui(self): + """Build the template editor UI""" + # Main frame with scrollbar + main_frame = ttk.Frame(self.dialog) + main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Basic Info + info_frame = ttk.LabelFrame(main_frame, text="Template Information", padding=10) + info_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Label(info_frame, text="Name:").grid(row=0, column=0, sticky=tk.W, pady=2) + self.name_var = tk.StringVar() + ttk.Entry(info_frame, textvariable=self.name_var, width=40).grid(row=0, column=1, sticky=tk.EW, pady=2) + + ttk.Label(info_frame, text="Description:").grid(row=1, column=0, sticky=tk.W, pady=2) + self.desc_var = tk.StringVar() + ttk.Entry(info_frame, textvariable=self.desc_var, width=40).grid(row=1, column=1, sticky=tk.EW, pady=2) + + info_frame.grid_columnconfigure(1, weight=1) + + # Template Components + components_frame = ttk.LabelFrame(main_frame, text="Template Components", padding=10) + components_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + # Create entry fields for all template components + self.template_vars = {} + components = [ + ("BOS Token", "bos_token"), + ("System Prefix", "system_prefix"), + ("System Suffix", "system_suffix"), + ("User Prefix", "user_prefix"), + ("User Suffix", "user_suffix"), + ("Assistant Prefix", "assistant_prefix"), + ("Assistant Suffix", "assistant_suffix"), + ("Turn Separator", "turn_separator"), + ("EOS Token", "eos_token"), + ] + + for i, (label, key) in enumerate(components): + ttk.Label(components_frame, text=f"{label}:").grid(row=i, column=0, sticky=tk.W, pady=2) + var = tk.StringVar() + entry = ttk.Entry(components_frame, textvariable=var, width=50) + entry.grid(row=i, column=1, sticky=tk.EW, pady=2, padx=(5, 0)) + self.template_vars[key] = var + + # Stop tokens (text area) + ttk.Label(components_frame, text="Stop Tokens (one per line):").grid(row=len(components), column=0, sticky=tk.W, pady=2) + self.stop_tokens_text = tk.Text(components_frame, height=4, width=50) + self.stop_tokens_text.grid(row=len(components), column=1, sticky=tk.EW, pady=2, padx=(5, 0)) + + # Add generation prompt checkbox + self.add_gen_prompt_var = tk.BooleanVar(value=True) + ttk.Checkbutton(components_frame, text="Add generation prompt", + variable=self.add_gen_prompt_var).grid(row=len(components)+1, column=0, columnspan=2, sticky=tk.W, pady=5) + + components_frame.grid_columnconfigure(1, weight=1) + + # Preview + preview_frame = ttk.LabelFrame(main_frame, text="Preview", padding=10) + preview_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Button(preview_frame, text="Generate Preview", command=self._generate_preview).pack(side=tk.LEFT) + + self.preview_text = tk.Text(preview_frame, height=6, wrap=tk.WORD) + self.preview_text.pack(fill=tk.BOTH, expand=True, pady=(10, 0)) + + # Buttons + button_frame = ttk.Frame(main_frame) + button_frame.pack(fill=tk.X, pady=(10, 0)) + + ttk.Button(button_frame, text="Save", command=self._save_template).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="Cancel", command=self.dialog.destroy).pack(side=tk.RIGHT) + + def _load_template_data(self): + """Load existing template data into the form""" + if not self.template: + return + + self.name_var.set(self.template.name) + self.desc_var.set(self.template.description) + + for key, var in self.template_vars.items(): + value = getattr(self.template, key, "") + var.set(value) + + # Load stop tokens + if self.template.stop_tokens: + self.stop_tokens_text.insert('1.0', '\n'.join(self.template.stop_tokens)) + + self.add_gen_prompt_var.set(self.template.add_generation_prompt) + + def _generate_preview(self): + """Generate a preview of the template formatting""" + try: + # Create a temporary template from current form data + template = self._create_template_from_form() + + # Sample conversation + sample_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you! How can I help you today?"}, + {"role": "user", "content": "Can you explain quantum physics?"} + ] + + # Create temporary manager to format + temp_manager = ChatTemplateManager() + temp_manager.templates["preview"] = template + + formatted = temp_manager.format_conversation("preview", sample_messages, True) + + self.preview_text.delete('1.0', tk.END) + self.preview_text.insert('1.0', formatted) + + except Exception as e: + self.preview_text.delete('1.0', tk.END) + self.preview_text.insert('1.0', f"Error generating preview: {e}") + + def _create_template_from_form(self) -> ChatTemplate: + """Create a ChatTemplate object from form data""" + # Get stop tokens + stop_tokens_text = self.stop_tokens_text.get('1.0', tk.END).strip() + stop_tokens = [token.strip() for token in stop_tokens_text.split('\n') if token.strip()] + + return ChatTemplate( + name=self.name_var.get().strip(), + description=self.desc_var.get().strip(), + bos_token=self.template_vars["bos_token"].get(), + system_prefix=self.template_vars["system_prefix"].get(), + system_suffix=self.template_vars["system_suffix"].get(), + user_prefix=self.template_vars["user_prefix"].get(), + user_suffix=self.template_vars["user_suffix"].get(), + assistant_prefix=self.template_vars["assistant_prefix"].get(), + assistant_suffix=self.template_vars["assistant_suffix"].get(), + turn_separator=self.template_vars["turn_separator"].get(), + eos_token=self.template_vars["eos_token"].get(), + stop_tokens=stop_tokens, + add_generation_prompt=self.add_gen_prompt_var.get() + ) + + def _save_template(self): + """Save the template""" + try: + template = self._create_template_from_form() + + # Validate required fields + if not template.name: + messagebox.showerror("Error", "Template name is required") + return + + self.result = template + self.dialog.destroy() + + except Exception as e: + messagebox.showerror("Error", f"Error saving template: {e}") + + def _center_window(self): + """Center the dialog on the parent window""" + self.dialog.update_idletasks() + + # Get parent position + parent_x = self.parent.winfo_x() + parent_y = self.parent.winfo_y() + parent_width = self.parent.winfo_width() + parent_height = self.parent.winfo_height() + + # Get dialog size + dialog_width = self.dialog.winfo_width() + dialog_height = self.dialog.winfo_height() + + # Calculate position + x = parent_x + (parent_width - dialog_width) // 2 + y = parent_y + (parent_height - dialog_height) // 2 + + self.dialog.geometry(f"+{x}+{y}") + + +# Global template manager instance +_template_manager = None + +def get_template_manager() -> ChatTemplateManager: + """Get the global chat template manager""" + global _template_manager + if _template_manager is None: + _template_manager = ChatTemplateManager() + return _template_manager \ No newline at end of file diff --git a/chess_tab.py b/chess_tab.py new file mode 100644 index 0000000..157007e --- /dev/null +++ b/chess_tab.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Chess Tab for DarkHal 2.0 + +Dedicated chess interface with AI opponent, game analysis, and UCI engine integration. +""" + +import tkinter as tk +from tkinter import ttk, messagebox, scrolledtext, filedialog +import os +import sys +import json +from pathlib import Path +from datetime import datetime +from typing import Optional, Dict, Any, List + + +class ChessTab: + """Dedicated Chess tab with AI opponent and advanced chess features.""" + + def __init__(self, parent: ttk.Frame, settings_manager): + self.parent = parent + self.settings = settings_manager + self.current_model = None + + # Create main frame + self.main_frame = ttk.Frame(parent) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create chess interface + self._create_chess_interface() + + def _create_chess_interface(self): + """Create the main chess interface.""" + + # Chess configuration frame + config_frame = ttk.LabelFrame(self.main_frame, text="Chess Game Configuration", padding=10) + config_frame.pack(fill=tk.X, pady=(0, 10)) + + # Configuration options + options_frame = ttk.Frame(config_frame) + options_frame.pack(fill=tk.X, pady=10) + + ttk.Label(options_frame, text="AI Difficulty:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.chess_difficulty_var = tk.StringVar(value="Medium") + ttk.Combobox(options_frame, textvariable=self.chess_difficulty_var, + values=["Easy", "Medium", "Hard", "Expert"], + state="readonly", width=15).grid(row=0, column=1, padx=10) + + ttk.Label(options_frame, text="Time Control:").grid(row=0, column=2, sticky=tk.W, padx=(20, 0)) + self.time_control_var = tk.StringVar(value="10+0") + ttk.Combobox(options_frame, textvariable=self.time_control_var, + values=["3+0", "5+0", "10+0", "15+10", "30+0"], + state="readonly", width=10).grid(row=0, column=3, padx=10) + + ttk.Label(options_frame, text="Play as:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.play_side_var = tk.StringVar(value="White") + ttk.Radiobutton(options_frame, text="White", variable=self.play_side_var, + value="White").grid(row=1, column=1, sticky=tk.W) + ttk.Radiobutton(options_frame, text="Black", variable=self.play_side_var, + value="Black").grid(row=1, column=2, sticky=tk.W) + + # Game control buttons + control_frame = ttk.Frame(config_frame) + control_frame.pack(fill=tk.X, pady=10) + + ttk.Button(control_frame, text="New Game", command=self._new_chess_game, + style="Accent.TButton").pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Start Chess", command=self._start_chess).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Analyze Position", command=self._analyze_position).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Get Hint", command=self._get_hint).pack(side=tk.LEFT, padx=5) + + # Chess engine information + engine_frame = ttk.LabelFrame(self.main_frame, text="Chess Engine Information", padding=10) + engine_frame.pack(fill=tk.X, pady=(0, 10)) + + engine_info = """DarkHal Chess Engine Features: + +• UCI (Universal Chess Interface) Protocol Support +• AI-powered move analysis and generation +• Multiple difficulty levels from beginner to expert +• Position evaluation and game analysis +• Opening book and endgame tablebase support +• Real-time move suggestions and hints +• Game saving/loading in PGN format +• Integration with ChessGPT model for enhanced play + +The chess engine uses advanced AI models to provide a challenging and educational chess experience. +You can adjust the difficulty to match your skill level and use analysis features to improve your game.""" + + ttk.Label(engine_frame, text=engine_info, wraplength=600, justify=tk.LEFT).pack(anchor=tk.W) + + # Game management frame + management_frame = ttk.LabelFrame(self.main_frame, text="Game Management", padding=10) + management_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + # Game management buttons + mgmt_buttons_frame = ttk.Frame(management_frame) + mgmt_buttons_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Button(mgmt_buttons_frame, text="Save Game", command=self._save_game).pack(side=tk.LEFT, padx=5) + ttk.Button(mgmt_buttons_frame, text="Load Game", command=self._load_game).pack(side=tk.LEFT, padx=5) + ttk.Button(mgmt_buttons_frame, text="Export PGN", command=self._export_pgn).pack(side=tk.LEFT, padx=5) + ttk.Button(mgmt_buttons_frame, text="Import PGN", command=self._import_pgn).pack(side=tk.LEFT, padx=5) + + # Game status and move history + status_frame = ttk.Frame(management_frame) + status_frame.pack(fill=tk.BOTH, expand=True) + + # Current game status + ttk.Label(status_frame, text="Game Status:", font=("Arial", 10, "bold")).pack(anchor=tk.W, pady=(0, 5)) + self.status_label = ttk.Label(status_frame, text="No game in progress", foreground="gray") + self.status_label.pack(anchor=tk.W, pady=(0, 10)) + + # Move history + ttk.Label(status_frame, text="Move History:", font=("Arial", 10, "bold")).pack(anchor=tk.W, pady=(0, 5)) + self.move_history = scrolledtext.ScrolledText(status_frame, height=8, width=50, state=tk.DISABLED) + self.move_history.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + # Chess engine status + engine_status_frame = ttk.Frame(management_frame) + engine_status_frame.pack(fill=tk.X) + + ttk.Label(engine_status_frame, text="Engine Status:", font=("Arial", 10, "bold")).pack(side=tk.LEFT) + self.engine_status_label = ttk.Label(engine_status_frame, text="Ready", foreground="green") + self.engine_status_label.pack(side=tk.LEFT, padx=(5, 0)) + + ttk.Button(engine_status_frame, text="Configure Engine", + command=self._configure_engine).pack(side=tk.RIGHT, padx=5) + + def _start_chess(self): + """Start a chess game with AI in floating window.""" + try: + from chess_window import open_chess_window + open_chess_window(self.parent.winfo_toplevel(), self.settings) + self._update_status("Chess game started") + except ImportError: + messagebox.showerror("Error", "Chess module not available. Please install python-chess.") + except Exception as e: + messagebox.showerror("Error", f"Failed to open chess window: {e}") + + def _new_chess_game(self): + """Start a new chess game.""" + difficulty = self.chess_difficulty_var.get() + time_control = self.time_control_var.get() + play_side = self.play_side_var.get() + + self._update_status(f"New game started - {play_side} vs AI ({difficulty})") + self._add_move_to_history(f"Game started: {play_side} vs AI") + self._add_move_to_history(f"Difficulty: {difficulty}, Time: {time_control}") + + # Start the actual chess game + self._start_chess() + + def _analyze_position(self): + """Analyze current chess position.""" + self._update_status("Analyzing position...") + self._add_move_to_history("Position analysis requested") + # TODO: Implement position analysis with AI + messagebox.showinfo("Analysis", "Position analysis feature coming soon!") + + def _get_hint(self): + """Get a hint for the current position.""" + self._update_status("Generating hint...") + self._add_move_to_history("Hint requested") + # TODO: Implement hint generation with AI + messagebox.showinfo("Hint", "Hint generation feature coming soon!") + + def _save_game(self): + """Save the current game.""" + filename = filedialog.asksaveasfilename( + title="Save Chess Game", + defaultextension=".pgn", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")] + ) + if filename: + # TODO: Implement game saving + self._add_move_to_history(f"Game saved to {filename}") + messagebox.showinfo("Saved", f"Game saved to {filename}") + + def _load_game(self): + """Load a saved game.""" + filename = filedialog.askopenfilename( + title="Load Chess Game", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")] + ) + if filename: + # TODO: Implement game loading + self._add_move_to_history(f"Game loaded from {filename}") + self._update_status("Game loaded") + messagebox.showinfo("Loaded", f"Game loaded from {filename}") + + def _export_pgn(self): + """Export current game to PGN format.""" + filename = filedialog.asksaveasfilename( + title="Export to PGN", + defaultextension=".pgn", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")] + ) + if filename: + # TODO: Implement PGN export + self._add_move_to_history(f"Game exported to PGN: {filename}") + messagebox.showinfo("Exported", f"Game exported to {filename}") + + def _import_pgn(self): + """Import game from PGN format.""" + filename = filedialog.askopenfilename( + title="Import PGN", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")] + ) + if filename: + # TODO: Implement PGN import + self._add_move_to_history(f"Game imported from PGN: {filename}") + self._update_status("PGN game imported") + messagebox.showinfo("Imported", f"Game imported from {filename}") + + def _configure_engine(self): + """Configure chess engine settings.""" + # TODO: Implement engine configuration dialog + messagebox.showinfo("Engine Config", "Engine configuration coming soon!") + + def _update_status(self, status: str): + """Update the game status display.""" + self.status_label.config(text=status) + timestamp = datetime.now().strftime("%H:%M:%S") + print(f"[{timestamp}] Chess: {status}") + + def _add_move_to_history(self, move: str): + """Add a move or event to the move history.""" + self.move_history.config(state=tk.NORMAL) + timestamp = datetime.now().strftime("%H:%M:%S") + self.move_history.insert(tk.END, f"[{timestamp}] {move}\n") + self.move_history.see(tk.END) + self.move_history.config(state=tk.DISABLED) + + def set_model(self, model_path: Optional[str]): + """Set the current AI model for chess analysis.""" + self.current_model = model_path + if model_path: + self._update_status(f"AI Model loaded: {Path(model_path).name}") + else: + self._update_status("No AI model loaded") \ No newline at end of file diff --git a/chess_window.py b/chess_window.py new file mode 100644 index 0000000..8a26f27 --- /dev/null +++ b/chess_window.py @@ -0,0 +1,1210 @@ +#!/usr/bin/env python3 +""" +Floating Chess Window for DarkHal 2.0 + +A separate window for playing chess against the AI with a visual board. +""" + +import tkinter as tk +from tkinter import ttk, messagebox, filedialog, scrolledtext +import subprocess +import sys +import os +import threading +from pathlib import Path +import json +import datetime + +try: + import chess + import chess.svg +except ImportError: + chess = None + + +class ChessWindow: + """Floating chess game window.""" + + def __init__(self, parent, settings_manager): + self.parent = parent + self.settings = settings_manager + self.window = None + self.board = None + self.engine_process = None + self.selected_square = None + self.move_history = [] + self.game_saved = True + self.llm_cache = None # Cache for LLM instance + self.ai_thinking = False + self.flipped_board = False + + if not chess: + messagebox.showerror("Missing Dependency", + "Please install python-chess: pip install python-chess") + return + + self._create_window() + + def _create_window(self): + """Create the floating chess window.""" + self.window = tk.Toplevel(self.parent) + self.window.title("DarkHal Chess - Human vs AI") + self.window.geometry("800x900") + self.window.resizable(False, False) + + # Set icon + try: + icon_path = os.path.join(os.path.dirname(__file__), "assets", "Halico.ico") + if os.path.exists(icon_path): + self.window.iconbitmap(icon_path) + except Exception: + pass + + # Initialize chess board + self.board = chess.Board() + + # Create UI + self._create_ui() + + # Handle window closing + self.window.protocol("WM_DELETE_WINDOW", self._on_closing) + + def _create_ui(self): + """Create the chess UI.""" + # Title frame + title_frame = ttk.Frame(self.window) + title_frame.pack(fill=tk.X, padx=10, pady=5) + + title_label = ttk.Label(title_frame, text="DarkHal Chess Engine", + font=("Arial", 16, "bold")) + title_label.pack() + + subtitle_label = ttk.Label(title_frame, text="Human vs AI Chess Match", + font=("Arial", 10)) + subtitle_label.pack() + + # Game controls frame + controls_frame = ttk.LabelFrame(self.window, text="Game Controls", padding=10) + controls_frame.pack(fill=tk.X, padx=10, pady=5) + + # Game settings + settings_frame = ttk.Frame(controls_frame) + settings_frame.pack(fill=tk.X) + + ttk.Label(settings_frame, text="Play as:").grid(row=0, column=0, sticky=tk.W, padx=5) + self.play_side_var = tk.StringVar(value="White") + ttk.Radiobutton(settings_frame, text="White", variable=self.play_side_var, + value="White", command=self._side_changed).grid(row=0, column=1) + ttk.Radiobutton(settings_frame, text="Black", variable=self.play_side_var, + value="Black", command=self._side_changed).grid(row=0, column=2) + + ttk.Label(settings_frame, text="AI Difficulty:").grid(row=0, column=3, sticky=tk.W, padx=(20, 5)) + self.difficulty_var = tk.StringVar(value="Medium") + ttk.Combobox(settings_frame, textvariable=self.difficulty_var, + values=["Easy", "Medium", "Hard", "Expert"], + state="readonly", width=10).grid(row=0, column=4) + + # Control buttons + buttons_frame = ttk.Frame(controls_frame) + buttons_frame.pack(fill=tk.X, pady=10) + + ttk.Button(buttons_frame, text="New Game", command=self._new_game).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Flip Board", command=self._flip_board).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Hint", command=self._get_hint).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Undo", command=self._undo_move).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Analyze", command=self._analyze_position).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Save Game", command=self._save_game).pack(side=tk.LEFT, padx=5) + ttk.Button(buttons_frame, text="Load Game", command=self._load_game).pack(side=tk.LEFT, padx=5) + + # Game status + status_frame = ttk.Frame(controls_frame) + status_frame.pack(fill=tk.X) + + ttk.Label(status_frame, text="Status:").pack(side=tk.LEFT) + self.status_var = tk.StringVar(value="Ready to play") + self.status_label = ttk.Label(status_frame, textvariable=self.status_var, + font=("Arial", 10, "bold")) + self.status_label.pack(side=tk.LEFT, padx=10) + + ttk.Label(status_frame, text="Turn:").pack(side=tk.LEFT, padx=(20, 0)) + self.turn_var = tk.StringVar(value="White") + self.turn_label = ttk.Label(status_frame, textvariable=self.turn_var, + font=("Arial", 10, "bold")) + self.turn_label.pack(side=tk.LEFT, padx=5) + + # Chess board frame + board_frame = ttk.LabelFrame(self.window, text="Chess Board", padding=10) + board_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) + + # Create chess board canvas + self.board_canvas = tk.Canvas(board_frame, width=640, height=640, bg="white") + self.board_canvas.pack() + + # Bind mouse events + self.board_canvas.bind("", self._on_square_click) + self.board_canvas.bind("", self._on_mouse_motion) + + # Move history frame + history_frame = ttk.LabelFrame(self.window, text="Move History", padding=10) + history_frame.pack(fill=tk.X, padx=10, pady=5) + + # Move history text + self.history_text = tk.Text(history_frame, height=4, wrap=tk.WORD, + font=("Consolas", 9), state=tk.DISABLED) + history_scroll = ttk.Scrollbar(history_frame, orient="vertical", command=self.history_text.yview) + self.history_text.configure(yscrollcommand=history_scroll.set) + + self.history_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + history_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Draw initial board + self._draw_board() + self._update_status() + + # Start AI if playing as black + if self.play_side_var.get() == "Black": + self._ai_move() + + def _draw_board(self): + """Draw the chess board and pieces.""" + self.board_canvas.delete("all") + + # Board dimensions + square_size = 80 + board_size = square_size * 8 + + # Colors + light_color = "#F0D9B5" + dark_color = "#B58863" + highlight_color = "#FFFF00" + + # Draw squares + for rank in range(8): + for file in range(8): + x1 = file * square_size + y1 = rank * square_size + x2 = x1 + square_size + y2 = y1 + square_size + + # Determine square color + if (rank + file) % 2 == 0: + color = light_color + else: + color = dark_color + + # Highlight selected square + square = chess.square(file, 7 - rank) + if square == self.selected_square: + color = highlight_color + + self.board_canvas.create_rectangle(x1, y1, x2, y2, fill=color, outline="black") + + # Draw coordinates + if file == 0: # Left edge - rank labels + self.board_canvas.create_text(x1 + 5, y1 + 10, text=str(8 - rank), + font=("Arial", 8), anchor="nw") + if rank == 7: # Bottom edge - file labels + self.board_canvas.create_text(x2 - 10, y2 - 5, text=chr(ord('a') + file), + font=("Arial", 8), anchor="se") + + # Draw pieces + self._draw_pieces() + + def _draw_pieces(self): + """Draw chess pieces on the board.""" + square_size = 80 + + # Unicode chess pieces + piece_symbols = { + chess.PAWN: "♟♙", chess.ROOK: "♜♖", chess.KNIGHT: "♞♘", + chess.BISHOP: "♝♗", chess.QUEEN: "♛♕", chess.KING: "♚♔" + } + + for square in chess.SQUARES: + piece = self.board.piece_at(square) + if piece: + file = chess.square_file(square) + rank = chess.square_rank(square) + + x = file * square_size + square_size // 2 + y = (7 - rank) * square_size + square_size // 2 + + # Get piece symbol + symbol = piece_symbols[piece.piece_type][0 if piece.color == chess.BLACK else 1] + + self.board_canvas.create_text(x, y, text=symbol, font=("Arial", 48), + fill="black" if piece.color == chess.BLACK else "white", + anchor="center") + + def _on_square_click(self, event): + """Handle square click events.""" + if self.board.is_game_over(): + return + + # Convert canvas coordinates to square + square_size = 80 + file = event.x // square_size + rank = 7 - (event.y // square_size) + + if 0 <= file <= 7 and 0 <= rank <= 7: + square = chess.square(file, rank) + + if self.selected_square is None: + # Select piece + piece = self.board.piece_at(square) + if piece and piece.color == self.board.turn: + # Only allow selecting our pieces on our turn + human_color = chess.WHITE if self.play_side_var.get() == "White" else chess.BLACK + if self.board.turn == human_color: + self.selected_square = square + self._draw_board() + else: + # Try to make move + try: + move = chess.Move(self.selected_square, square) + + # Check for promotion + piece = self.board.piece_at(self.selected_square) + if (piece and piece.piece_type == chess.PAWN and + ((piece.color == chess.WHITE and rank == 7) or + (piece.color == chess.BLACK and rank == 0))): + # Auto-promote to queen for simplicity + move = chess.Move(self.selected_square, square, promotion=chess.QUEEN) + + if move in self.board.legal_moves: + self._make_human_move(move) + else: + messagebox.showwarning("Illegal Move", "That move is not legal!") + + except Exception as e: + messagebox.showerror("Move Error", f"Error making move: {e}") + + self.selected_square = None + self._draw_board() + + def _on_mouse_motion(self, event): + """Handle mouse motion for hover effects.""" + # Could add hover highlighting here + pass + + def _make_human_move(self, move): + """Make a human move and trigger AI response.""" + print(f"[CHESS DEBUG] Making human move: {move.uci()}") + + # Get SAN notation BEFORE applying the move + san_notation = self.board.san(move) + print(f"[CHESS DEBUG] Human move SAN notation: {san_notation}") + + # Add move to board + self.board.push(move) + print(f"[CHESS DEBUG] Move applied to board. New position: {self.board.fen()}") + + self._add_move_to_history_with_san(move, san_notation) + self._draw_board() + self._update_status() + + # Check if game is over + if self.board.is_game_over(): + print(f"[CHESS DEBUG] Game is over after human move") + self._game_over() + return + + print(f"[CHESS DEBUG] Starting AI move in background thread") + # AI's turn + threading.Thread(target=self._ai_move, daemon=True).start() + + def _ai_move(self): + """Let AI make a move.""" + print(f"[CHESS DEBUG] AI move started") + + if self.board.is_game_over(): + print(f"[CHESS DEBUG] Game is over, AI move cancelled") + return + + # Check whose turn it is + ai_color = chess.BLACK if self.play_side_var.get() == "White" else chess.WHITE + ai_color_str = "Black" if ai_color == chess.BLACK else "White" + board_turn_str = "Black" if self.board.turn == chess.BLACK else "White" + + if self.board.turn != ai_color: + print(f"[CHESS DEBUG] Not AI's turn! Board turn: {board_turn_str}, AI color: {ai_color_str}") + return + + print(f"[CHESS DEBUG] AI's turn confirmed. Board turn: {board_turn_str}, AI color: {ai_color_str}") + + # Update status + self.window.after(0, lambda: self.status_var.set("AI is thinking...")) + + try: + # Get AI move using simple logic for now + # In a full implementation, this would call the UCI engine + ai_move = self._get_ai_move() + + if ai_move: + print(f"[CHESS DEBUG] AI found move: {ai_move.uci()}") + self.window.after(0, lambda: self._apply_ai_move(ai_move)) + else: + print(f"[CHESS DEBUG] AI couldn't find a move") + self.window.after(0, lambda: self.status_var.set("AI couldn't find a move")) + + except Exception as e: + print(f"[CHESS DEBUG] AI move exception: {e}") + error_msg = str(e) + self.window.after(0, lambda: messagebox.showerror("AI Error", f"AI move failed: {error_msg}")) + + def _get_ai_move(self): + """Get AI move using LLM integration with repetition avoidance.""" + legal_moves = list(self.board.legal_moves) + if not legal_moves: + return None + + # Filter out moves that would repeat recent positions + filtered_moves = self._filter_repetitive_moves(legal_moves) + if not filtered_moves: + filtered_moves = legal_moves # Fall back to all legal moves if filtering removes everything + + # Try LLM first + llm_move = self._query_llm_for_move() + if llm_move and llm_move in filtered_moves: + return llm_move + elif llm_move and llm_move in legal_moves: + print(f"[CHESS DEBUG] LLM suggested repetitive move: {llm_move.uci()}, using fallback") + + # Fallback to strategic heuristics with filtered moves + return self._get_strategic_move(filtered_moves) + + def _filter_repetitive_moves(self, legal_moves): + """Filter out moves that would create repetitive positions.""" + if len(self.board.move_stack) < 4: + return legal_moves # Not enough moves to check repetition + + current_fen = self.board.fen().split()[0] # Just board position + filtered_moves = [] + + for move in legal_moves: + # Test if this move would create a repetition + self.board.push(move) + new_fen = self.board.fen().split()[0] + self.board.pop() + + # Check if this position appeared recently + is_repetitive = False + temp_board = chess.Board() + recent_positions = [] + + # Build recent position history + for historical_move in self.board.move_stack[-6:]: + temp_board.push(historical_move) + recent_positions.append(temp_board.fen().split()[0]) + + # Check if new position would repeat a recent one + if new_fen in recent_positions[-4:]: # Last 2 moves + is_repetitive = True + print(f"[CHESS DEBUG] Filtering repetitive move: {move.uci()}") + + if not is_repetitive: + filtered_moves.append(move) + + return filtered_moves if filtered_moves else legal_moves + + def _query_llm_for_move(self): + """Query the LLM for a chess move using DarkHal's main chat system.""" + try: + print(f"[CHESS DEBUG] Starting LLM query for move") + + # Check if chess mode is enabled + chess_mode = self.settings.get('model_settings.chess_mode', False) + print(f"[CHESS DEBUG] Chess mode enabled: {chess_mode}") + + # Import DarkHal's main chat function + sys.path.append(os.path.dirname(os.path.dirname(__file__))) + from main import run_prompt + + # Get model path and settings + model_path = self.settings.get('paths.last_model_path', '') + if not model_path or not os.path.exists(model_path): + print(f"[CHESS DEBUG] No valid model path: {model_path}") + return None + + # Prepare the chess prompt + prompt = self._create_chess_prompt() + print(f"[CHESS DEBUG] Created chess prompt: {len(prompt)} characters") + + # Get settings for LLM + n_ctx = self.settings.get('model_settings.default_n_ctx', 4096) + n_gpu_layers = self.settings.get('model_settings.default_n_gpu_layers', 0) + + # Use multiple attempts with different temperatures + max_attempts = 3 + temperatures = [0.1, 0.3, 0.5] + + for attempt in range(max_attempts): + try: + print(f"[CHESS DEBUG] Attempt {attempt + 1} with temperature {temperatures[attempt]}") + + # Use DarkHal's run_prompt function (which returns a string directly) + response_text = run_prompt( + model_path=model_path, + prompt=prompt, + stream=False, + n_ctx=n_ctx, + n_gpu_layers=n_gpu_layers, + max_tokens=20, + chess_mode=chess_mode + ) + + print(f"[CHESS DEBUG] LLM response: '{response_text}'") + + # Parse the response text directly + move = self._parse_move_from_response(response_text) + + if move and move in self.board.legal_moves: + print(f"[CHESS DEBUG] Valid move found: {move.uci()} (attempt {attempt + 1})") + return move + elif move: + print(f"[CHESS DEBUG] Invalid move suggested: {move.uci()} not in legal moves") + else: + print(f"[CHESS DEBUG] Could not parse move from response") + + except Exception as e: + print(f"[CHESS DEBUG] LLM attempt {attempt + 1} failed: {e}") + import traceback + traceback.print_exc() + continue + + print(f"[CHESS DEBUG] All LLM attempts failed") + return None + + except Exception as e: + print(f"[CHESS DEBUG] LLM query failed: {e}") + import traceback + traceback.print_exc() + return None + + def _parse_move_from_response(self, text): + """Parse chess move from LLM response using multiple methods.""" + # Clean the text + text = text.strip().lower() + legal_moves = list(self.board.legal_moves) + legal_uci = [move.uci() for move in legal_moves] + + # Method 1: Direct UCI format match + for move_uci in legal_uci: + if move_uci in text: + return chess.Move.from_uci(move_uci) + + # Method 2: Look for 4-5 character sequences that could be UCI + import re + uci_pattern = r'\b[a-h][1-8][a-h][1-8][qrbn]?\b' + matches = re.findall(uci_pattern, text) + for match in matches: + if match in legal_uci: + return chess.Move.from_uci(match) + + # Method 3: Try to find SAN notation and convert + # Use a safe approach that doesn't call .san() on invalid moves + for move in legal_moves: + try: + # Calculate SAN for this legal move + san = self.board.san(move).lower() + san_clean = san.replace('+', '').replace('#', '').replace('x', '') + if san_clean in text or san in text: + return move + except: + # Skip if SAN calculation fails + continue + + # Method 4: Extract first plausible move-like string + tokens = text.replace(',', ' ').replace('.', ' ').split() + for token in tokens: + token = token.strip('.,()[]{}') + if len(token) >= 4 and len(token) <= 5: + try: + # Try as UCI + if token in legal_uci: + return chess.Move.from_uci(token) + except: + continue + + return None + + def _get_llm_instance(self): + """Get cached LLM instance.""" + if self.llm_cache: + return self.llm_cache + + try: + # Import from main project + sys.path.append('..') + from llama_cpp import Llama + + # Get model path from settings + model_path = self.settings.get('paths.last_model_path', '') + if not model_path or not os.path.exists(model_path): + return None + + # Create LLM instance + self.llm_cache = Llama( + model_path=model_path, + n_ctx=self.settings.get('model_settings.default_n_ctx', 4096), + n_gpu_layers=self.settings.get('model_settings.default_n_gpu_layers', 0), + verbose=False + ) + + return self.llm_cache + + except Exception as e: + print(f"Failed to load LLM: {e}") + return None + + def _create_chess_prompt(self): + """Create a chess-specific prompt for the LLM with anti-repetition measures.""" + # Get game context + fen = self.board.fen() + legal_moves = [move.uci() for move in self.board.legal_moves] + # Get move history safely - use stored SAN notation instead of converting + move_history = [] + for item in self.move_history[-10:]: + if ' ' in item: + move_history.append(item.split(' ')[0]) # Get just the SAN part + + # Determine player color + ai_color = "Black" if self.play_side_var.get() == "White" else "White" + current_turn = "White" if self.board.turn == chess.WHITE else "Black" + + # Get recent position repetitions + recent_positions = [] + temp_board = chess.Board() + for move in self.board.move_stack[-6:]: # Last 3 moves (6 half-moves) + temp_board.push(move) + recent_positions.append(temp_board.fen().split()[0]) # Just board position, not full FEN + + # Check for repetitive patterns + repetition_warning = "" + if len(recent_positions) >= 4: + if recent_positions[-1] == recent_positions[-3] and recent_positions[-2] == recent_positions[-4]: + repetition_warning = "WARNING: Position is repeating! Choose a different strategy to avoid draws." + + # Analyze game phase + material_count = len([p for p in str(self.board) if p.isalpha()]) + if material_count > 28: + game_phase = "opening" + phase_advice = "Focus on piece development, center control, and king safety." + elif material_count > 16: + game_phase = "middlegame" + phase_advice = "Look for tactical combinations and improve piece coordination." + else: + game_phase = "endgame" + phase_advice = "Activate your king, push passed pawns, and simplify advantageous positions." + + # Create enhanced strategic prompt + prompt = f"""You are an expert chess player playing as {ai_color}. It's {current_turn}'s turn. + +Position (FEN): {fen} +Game phase: {game_phase} +Recent moves: {' '.join(move_history) if move_history else 'Game start'} + +{repetition_warning} + +Available moves (UCI format): {', '.join(legal_moves[:20])} + +Strategic priorities for {game_phase}: +{phase_advice} + +Choose the BEST move considering: +1. Avoid repeating recent moves or positions +2. King safety and piece protection +3. {phase_advice.split('.')[0].lower()} +4. Tactical opportunities (captures, forks, pins, skewers) +5. Long-term strategic advantages + +IMPORTANT: Do not repeat the same move or create position repetitions. + +Respond with ONLY the move in UCI format (e.g., e2e4, g1f3, e7e8q):""" + + return prompt + + def _get_strategic_move(self, legal_moves): + """Get strategic move using chess heuristics.""" + import random + + scored_moves = [] + + for move in legal_moves: + score = 0 + + # Make move temporarily to evaluate + self.board.push(move) + + # Prefer captures with good piece values + if self.board.move_stack and self.board.is_capture(self.board.move_stack[-1]): + captured_piece = self.board.piece_at(move.to_square) + if captured_piece: + piece_values = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3, + chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0} + score += piece_values.get(captured_piece.piece_type, 0) * 10 + + # Prefer central squares + to_square = move.to_square + file = chess.square_file(to_square) + rank = chess.square_rank(to_square) + center_distance = abs(3.5 - file) + abs(3.5 - rank) + score += (7 - center_distance) * 2 + + # Prefer piece development + piece = self.board.piece_at(move.from_square) + if piece and piece.piece_type in [chess.KNIGHT, chess.BISHOP]: + if chess.square_rank(move.from_square) in [0, 7]: # From back rank + score += 15 + + # Avoid moving king early (unless castling) + if piece and piece.piece_type == chess.KING and not self.board.is_castling(move): + if len(self.board.move_stack) < 10: # Early game + score -= 10 + + # Prefer checks + if self.board.is_check(): + score += 5 + + # Avoid leaving pieces hanging + if self.board.is_attacked_by(not self.board.turn, move.to_square): + score -= piece_values.get(piece.piece_type if piece else chess.PAWN, 0) * 5 + + self.board.pop() # Undo temporary move + + scored_moves.append((move, score)) + + # Sort by score and add randomness for variety + scored_moves.sort(key=lambda x: x[1] + random.random() * 2, reverse=True) + return scored_moves[0][0] + + def _apply_ai_move(self, move): + """Apply AI move to the board with validation.""" + try: + print(f"[CHESS DEBUG] Applying AI move: {move.uci()}") + + # Validate move is legal + if move not in self.board.legal_moves: + print(f"[CHESS DEBUG] Illegal AI move attempted: {move.uci()}") + return False + + # Get SAN notation BEFORE applying the move + san_notation = self.board.san(move) + print(f"[CHESS DEBUG] Move SAN notation: {san_notation}") + + # Apply move + self.board.push(move) + print(f"[CHESS DEBUG] Move applied to board") + + # Add to history using the pre-calculated SAN + self._add_move_to_history_with_san(move, san_notation) + self._draw_board() + self._update_status() + + print(f"[CHESS DEBUG] AI played: {san_notation} ({move.uci()})") + + if self.board.is_game_over(): + print(f"[CHESS DEBUG] Game over after AI move") + self._game_over() + + return True + + except Exception as e: + print(f"[CHESS DEBUG] Error applying AI move: {e}") + import traceback + traceback.print_exc() + return False + + def _add_move_to_history(self, move): + """Add move to history display (deprecated - use _add_move_to_history_with_san).""" + # This method is kept for compatibility but should not be used + # as it can cause the SAN notation error + print(f"[CHESS DEBUG] WARNING: Using deprecated _add_move_to_history method") + try: + san_notation = self.board.san(move) + self._add_move_to_history_with_san(move, san_notation) + except Exception as e: + print(f"[CHESS DEBUG] Error in deprecated history method: {e}") + # Fallback to just UCI notation + self._add_move_to_history_with_san(move, move.uci()) + + def _add_move_to_history_with_san(self, move, san_notation): + """Add move to history display using pre-calculated SAN notation.""" + print(f"[CHESS DEBUG] Adding move to history: {san_notation} ({move.uci()})") + + self.history_text.config(state=tk.NORMAL) + + move_num = len(self.board.move_stack) // 2 + 1 + if len(self.board.move_stack) % 2 == 1: # White move (odd number in stack) + self.history_text.insert(tk.END, f"{move_num}. {san_notation} ") + else: # Black move (even number in stack) + self.history_text.insert(tk.END, f"{san_notation}\n") + + self.history_text.see(tk.END) + self.history_text.config(state=tk.DISABLED) + + def _update_status(self): + """Update game status display.""" + if self.board.is_game_over(): + if self.board.is_checkmate(): + winner = "White" if self.board.turn == chess.BLACK else "Black" + self.status_var.set(f"Checkmate! {winner} wins!") + elif self.board.is_stalemate(): + self.status_var.set("Stalemate - Draw!") + elif self.board.is_insufficient_material(): + self.status_var.set("Draw - Insufficient material") + else: + self.status_var.set("Game Over") + elif self.board.is_check(): + self.status_var.set("Check!") + else: + self.status_var.set("Game in progress") + + # Update turn + self.turn_var.set("White" if self.board.turn == chess.WHITE else "Black") + + def _game_over(self): + """Handle game over.""" + result = "Unknown" + if self.board.is_checkmate(): + winner = "White" if self.board.turn == chess.BLACK else "Black" + result = f"{winner} wins by checkmate!" + elif self.board.is_stalemate(): + result = "Draw by stalemate!" + elif self.board.is_insufficient_material(): + result = "Draw by insufficient material!" + + messagebox.showinfo("Game Over", result) + + def _new_game(self): + """Start a new game.""" + if not self.game_saved: + result = messagebox.askyesnocancel("Unsaved Game", + "Current game is not saved. Save before starting new game?") + if result is True: # Yes - save first + if not self._save_game(): + return # Save cancelled + elif result is None: # Cancel + return + + self.board = chess.Board() + self.selected_square = None + self.move_history = [] + self.game_saved = True + + # Clear history display + self.history_text.config(state=tk.NORMAL) + self.history_text.delete(1.0, tk.END) + self.history_text.config(state=tk.DISABLED) + + self._draw_board() + self._update_status() + + # AI goes first if human plays black + if self.play_side_var.get() == "Black": + threading.Thread(target=self._ai_move, daemon=True).start() + + def _flip_board(self): + """Flip the board view.""" + self.flipped_board = not self.flipped_board + self._draw_board() + messagebox.showinfo("Board Flipped", f"Board view {'flipped' if self.flipped_board else 'normal'}") + + def _get_hint(self): + """Get a hint for the current position using AI analysis.""" + if self.board.is_game_over(): + messagebox.showinfo("Hint", "Game is over!") + return + + if self.ai_thinking: + messagebox.showinfo("Hint", "AI is currently thinking. Please wait.") + return + + # Check if it's human's turn + human_color = chess.WHITE if self.play_side_var.get() == "White" else chess.BLACK + if self.board.turn != human_color: + messagebox.showinfo("Hint", "It's not your turn!") + return + + # Get AI suggestion + self.status_var.set("Analyzing position for hint...") + + def get_hint_thread(): + try: + # Use the same AI logic to get best move + hint_move = self._get_ai_move() + + if hint_move: + from_square = chess.square_name(hint_move.from_square) + to_square = chess.square_name(hint_move.to_square) + + # Get piece name + piece = self.board.piece_at(hint_move.from_square) + piece_name = piece.symbol().upper() if piece else "Piece" + + # Check if it's a special move + move_type = "" + if self.board.is_capture(hint_move): + move_type += " (Capture)" + if self.board.is_castling(hint_move): + move_type += " (Castling)" + if hint_move.promotion: + move_type += f" (Promote to {chess.piece_name(hint_move.promotion)})" + + hint_text = f"Suggested move: {piece_name} from {from_square} to {to_square}{move_type}\n\n" + hint_text += f"Move notation: {self.board.san(hint_move)}\n" + hint_text += f"UCI format: {hint_move.uci()}" + + self.window.after(0, lambda: messagebox.showinfo("Chess Hint", hint_text)) + else: + self.window.after(0, lambda: messagebox.showinfo("Hint", "No good moves found!")) + + self.window.after(0, lambda: self.status_var.set("Ready")) + + except Exception as e: + self.window.after(0, lambda: messagebox.showerror("Hint Error", f"Failed to get hint: {e}")) + self.window.after(0, lambda: self.status_var.set("Ready")) + + threading.Thread(target=get_hint_thread, daemon=True).start() + + def _undo_move(self): + """Undo the last move(s).""" + if self.ai_thinking: + messagebox.showinfo("Undo", "Cannot undo while AI is thinking.") + return + + if len(self.board.move_stack) >= 2: + # Undo both AI and human moves + self.board.pop() + self.board.pop() + self.move_history = self.move_history[:-2] + self.game_saved = False + + # Update history display + self.history_text.config(state=tk.NORMAL) + self.history_text.delete(1.0, tk.END) + + # Rebuild history display + for i, move in enumerate(self.move_history): + move_num = (i + 2) // 2 + if i % 2 == 0: # White move + self.history_text.insert(tk.END, f"{move_num}. {move} ") + else: # Black move + self.history_text.insert(tk.END, f"{move}\n") + + self.history_text.config(state=tk.DISABLED) + + self._draw_board() + self._update_status() + + elif len(self.board.move_stack) == 1: + # Undo only one move + self.board.pop() + self.move_history = self.move_history[:-1] + self.game_saved = False + + self.history_text.config(state=tk.NORMAL) + self.history_text.delete(1.0, tk.END) + self.history_text.config(state=tk.DISABLED) + + self._draw_board() + self._update_status() + else: + messagebox.showinfo("Undo", "No moves to undo!") + + def _analyze_position(self): + """Analyze current position with detailed information.""" + if self.ai_thinking: + messagebox.showinfo("Analysis", "AI is currently thinking. Please wait.") + return + + # Create analysis window + analysis_window = tk.Toplevel(self.window) + analysis_window.title("Position Analysis") + analysis_window.geometry("600x500") + analysis_window.transient(self.window) + + # Create notebook for different analysis types + notebook = ttk.Notebook(analysis_window) + notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Basic analysis tab + basic_frame = ttk.Frame(notebook) + notebook.add(basic_frame, text="Basic Info") + + basic_text = tk.Text(basic_frame, wrap=tk.WORD, font=("Consolas", 10)) + basic_scroll = ttk.Scrollbar(basic_frame, orient="vertical", command=basic_text.yview) + basic_text.configure(yscrollcommand=basic_scroll.set) + + basic_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + basic_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Populate basic analysis + analysis = f"Position Analysis\n{'='*50}\n\n" + analysis += f"FEN: {self.board.fen()}\n\n" + analysis += f"Turn: {'White' if self.board.turn == chess.WHITE else 'Black'}\n" + analysis += f"Move number: {self.board.fullmove_number}\n" + analysis += f"Half-move clock: {self.board.halfmove_clock}\n\n" + + analysis += f"Legal moves: {len(list(self.board.legal_moves))}\n" + analysis += f"In check: {'Yes' if self.board.is_check() else 'No'}\n" + analysis += f"Can castle kingside: {self.board.has_kingside_castling_rights(self.board.turn)}\n" + analysis += f"Can castle queenside: {self.board.has_queenside_castling_rights(self.board.turn)}\n\n" + + # Material count + analysis += "Material Count:\n" + for color in [chess.WHITE, chess.BLACK]: + color_name = "White" if color == chess.WHITE else "Black" + analysis += f"\n{color_name}:\n" + for piece_type in [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN]: + count = len(self.board.pieces(piece_type, color)) + if count > 0: + analysis += f" {chess.piece_name(piece_type).title()}s: {count}\n" + + basic_text.insert(1.0, analysis) + basic_text.config(state=tk.DISABLED) + + # AI Analysis tab + ai_frame = ttk.Frame(notebook) + notebook.add(ai_frame, text="AI Analysis") + + ai_text = scrolledtext.ScrolledText(ai_frame, wrap=tk.WORD, font=("Consolas", 10)) + ai_text.pack(fill=tk.BOTH, expand=True) + + # Get AI analysis + def get_ai_analysis(): + ai_text.insert(tk.END, "Getting AI analysis...\n\n") + try: + llm = self._get_llm_instance() + if llm: + analysis_prompt = f"""Analyze this chess position as an expert player: + +Position (FEN): {self.board.fen()} +Turn: {'White' if self.board.turn == chess.WHITE else 'Black'} +Recent moves: {' '.join([self.board.san(move) for move in self.board.move_stack[-5:]])} + +Provide analysis covering: +1. Position evaluation (who's better and why) +2. Key tactical and positional themes +3. Best moves for the current player +4. Strategic plans for both sides +5. Critical weaknesses to address + +Analysis:""" + + # Use consistent temperature based on analysis depth + response = llm(analysis_prompt, max_tokens=500, temperature=0.4) + ai_analysis = response['choices'][0]['text'].strip() + + ai_text.delete(1.0, tk.END) + ai_text.insert(1.0, ai_analysis) + else: + ai_text.delete(1.0, tk.END) + ai_text.insert(1.0, "AI analysis not available - no model loaded") + + except Exception as e: + ai_text.delete(1.0, tk.END) + ai_text.insert(1.0, f"AI analysis failed: {e}") + + threading.Thread(target=get_ai_analysis, daemon=True).start() + + def _save_game(self): + """Save the current game to a PGN file.""" + if len(self.board.move_stack) == 0: + messagebox.showinfo("Save Game", "No moves to save!") + return False + + try: + # Ask for save location + filename = filedialog.asksaveasfilename( + title="Save Chess Game", + defaultextension=".pgn", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")], + initialname=f"darkhal_chess_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pgn" + ) + + if not filename: + return False + + # Create PGN content + pgn_content = self._create_pgn() + + # Write to file + with open(filename, 'w', encoding='utf-8') as f: + f.write(pgn_content) + + self.game_saved = True + messagebox.showinfo("Game Saved", f"Game saved to {os.path.basename(filename)}") + return True + + except Exception as e: + messagebox.showerror("Save Error", f"Failed to save game: {e}") + return False + + def _load_game(self): + """Load a game from a PGN file.""" + if not self.game_saved: + result = messagebox.askyesnocancel("Unsaved Game", + "Current game is not saved. Save before loading?") + if result is True: # Yes - save first + if not self._save_game(): + return # Save cancelled + elif result is None: # Cancel + return + + try: + # Ask for file to load + filename = filedialog.askopenfilename( + title="Load Chess Game", + filetypes=[("PGN files", "*.pgn"), ("All files", "*.*")] + ) + + if not filename: + return + + # Read PGN file + with open(filename, 'r', encoding='utf-8') as f: + pgn_content = f.read() + + # Parse and load game + if self._parse_pgn(pgn_content): + self.game_saved = True + messagebox.showinfo("Game Loaded", f"Game loaded from {os.path.basename(filename)}") + else: + messagebox.showerror("Load Error", "Invalid PGN format or unsupported game") + + except Exception as e: + messagebox.showerror("Load Error", f"Failed to load game: {e}") + + def _create_pgn(self): + """Create PGN content from current game.""" + # PGN headers + headers = [ + '[Event "DarkHal Chess Game"]', + f'[Date "{datetime.datetime.now().strftime("%Y.%m.%d")}"]', + '[White "Human"]' if self.play_side_var.get() == "White" else '[White "DarkHal AI"]', + '[Black "DarkHal AI"]' if self.play_side_var.get() == "White" else '[Black "Human"]', + f'[Site "DarkHal 2.0"]', + '[Round "1"]' + ] + + # Game result + if self.board.is_game_over(): + if self.board.is_checkmate(): + result = "1-0" if self.board.turn == chess.BLACK else "0-1" + else: + result = "1/2-1/2" + else: + result = "*" + + headers.append(f'[Result "{result}"]') + + # Create moves section + moves = [] + temp_board = chess.Board() + + for i, move in enumerate(self.board.move_stack): + if i % 2 == 0: # White move + move_num = (i // 2) + 1 + moves.append(f"{move_num}. {temp_board.san(move)}") + else: # Black move + moves.append(temp_board.san(move)) + temp_board.push(move) + + moves_text = " ".join(moves) + if result != "*": + moves_text += f" {result}" + + # Combine headers and moves + pgn = "\n".join(headers) + "\n\n" + moves_text + "\n" + return pgn + + def _parse_pgn(self, pgn_content): + """Parse PGN content and load the game.""" + try: + # Simple PGN parser - extract moves section + lines = pgn_content.strip().split('\n') + moves_text = "" + + # Find the moves section (after headers) + in_moves = False + for line in lines: + line = line.strip() + if not line: + continue + if line.startswith('['): + continue + else: + moves_text += line + " " + in_moves = True + + if not moves_text: + return False + + # Clean up moves text + moves_text = moves_text.replace('\n', ' ').strip() + + # Remove result markers + for result in ['1-0', '0-1', '1/2-1/2', '*']: + moves_text = moves_text.replace(result, '').strip() + + # Parse moves + self.board = chess.Board() + self.history_text.config(state=tk.NORMAL) + self.history_text.delete(1.0, tk.END) + self.history_text.config(state=tk.DISABLED) + + # Split into tokens and process + tokens = moves_text.split() + for token in tokens: + token = token.strip('.') + if not token or token.isdigit(): + continue + + try: + # Try to parse as SAN (Standard Algebraic Notation) + move = self.board.parse_san(token) + # Get SAN notation before applying move + san_notation = self.board.san(move) + self.board.push(move) + self._add_move_to_history_with_san(move, san_notation) + except: + # Skip invalid moves + continue + + self._draw_board() + self._update_status() + return True + + except Exception as e: + print(f"PGN parsing error: {e}") + return False + + def _side_changed(self): + """Handle play side change.""" + if hasattr(self, 'board') and self.board: + # Reset game when side changes + self._new_game() + + def _on_closing(self): + """Handle window closing.""" + if self.engine_process: + try: + self.engine_process.terminate() + except: + pass + + self.window.destroy() + + def _on_difficulty_changed(self, event=None): + """Handle difficulty setting change.""" + difficulty = self.difficulty_var.get() + self.status_var.set(f"AI difficulty set to {difficulty}") + + # Clear LLM cache to ensure new difficulty settings take effect + self.llm_cache = None + + +def open_chess_window(parent, settings_manager): + """Open the floating chess window.""" + ChessWindow(parent, settings_manager) \ No newline at end of file diff --git a/convert_icon.py b/convert_icon.py new file mode 100644 index 0000000..0b03013 --- /dev/null +++ b/convert_icon.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Icon Converter Utility +Converts PNG images to ICO format for Windows applications +""" + +import os +import sys +from pathlib import Path + +try: + from PIL import Image, ImageDraw +except ImportError: + print("Pillow is required for image conversion. Installing...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "Pillow"]) + from PIL import Image, ImageDraw + + +def create_default_icon(): + """Create a default LLM_Train icon if no PNG is provided.""" + # Create a 256x256 icon with gradient background + size = 256 + image = Image.new('RGBA', (size, size), (0, 0, 0, 0)) + draw = ImageDraw.Draw(image) + + # Create gradient background + for i in range(size): + alpha = int(255 * (i / size)) + color = (45, 85, 135, alpha) # Blue gradient + draw.line([(0, i), (size, i)], fill=color) + + # Draw "LLM" text-like representation + center = size // 2 + + # Draw stylized "LLM" using rectangles + # L + draw.rectangle([40, 60, 60, 180], fill=(255, 255, 255, 255)) + draw.rectangle([40, 160, 100, 180], fill=(255, 255, 255, 255)) + + # L + draw.rectangle([110, 60, 130, 180], fill=(255, 255, 255, 255)) + draw.rectangle([110, 160, 170, 180], fill=(255, 255, 255, 255)) + + # M + draw.rectangle([180, 60, 200, 180], fill=(255, 255, 255, 255)) + draw.rectangle([200, 60, 220, 100], fill=(255, 255, 255, 255)) + draw.rectangle([210, 100, 230, 120], fill=(255, 255, 255, 255)) + draw.rectangle([220, 120, 240, 180], fill=(255, 255, 255, 255)) + + # Add subtle border + draw.rectangle([10, 10, size-10, size-10], outline=(255, 255, 255, 128), width=2) + + return image + + +def png_to_ico(png_path, ico_path=None, sizes=None): + """ + Convert PNG to ICO format with multiple sizes. + + Args: + png_path (str): Path to the input PNG file + ico_path (str): Path for the output ICO file (optional) + sizes (list): List of sizes to include in ICO (default: [16, 32, 48, 64, 128, 256]) + """ + if sizes is None: + sizes = [16, 32, 48, 64, 128, 256] + + png_path = Path(png_path) + + if not png_path.exists(): + print(f"PNG file not found: {png_path}") + print("Creating default icon instead...") + image = create_default_icon() + else: + try: + image = Image.open(png_path) + print(f"Loaded PNG: {png_path} ({image.size})") + except Exception as e: + print(f"Error loading PNG: {e}") + print("Creating default icon instead...") + image = create_default_icon() + + # Convert to RGBA if not already + if image.mode != 'RGBA': + image = image.convert('RGBA') + + # Create different sizes + icon_images = [] + for size in sizes: + # Resize image maintaining aspect ratio + resized = image.resize((size, size), Image.Resampling.LANCZOS) + icon_images.append(resized) + print(f"Created {size}x{size} icon") + + # Determine output path + if ico_path is None: + ico_path = png_path.with_suffix('.ico') + else: + ico_path = Path(ico_path) + + # Save as ICO + try: + icon_images[0].save( + ico_path, + format='ICO', + sizes=[(img.width, img.height) for img in icon_images], + append_images=icon_images[1:] + ) + print(f"Successfully created ICO: {ico_path}") + return str(ico_path) + except Exception as e: + print(f"Error creating ICO: {e}") + return None + + +def convert_assets_folder(assets_dir="assets"): + """Convert all PNG files in assets folder to ICO format.""" + assets_path = Path(assets_dir) + + if not assets_path.exists(): + print(f"Assets directory not found: {assets_path}") + return [] + + png_files = list(assets_path.glob("*.png")) + + if not png_files: + print(f"No PNG files found in {assets_path}") + print("Creating default application icon...") + + # Create default icon + default_ico = assets_path / "llm_train.ico" + png_to_ico("nonexistent.png", default_ico) + return [str(default_ico)] + + converted_files = [] + for png_file in png_files: + ico_file = png_file.with_suffix('.ico') + result = png_to_ico(png_file, ico_file) + if result: + converted_files.append(result) + + return converted_files + + +def main(): + """Main function for command line usage.""" + import argparse + + parser = argparse.ArgumentParser(description="Convert PNG images to ICO format") + parser.add_argument("input", nargs="?", help="Input PNG file or assets directory") + parser.add_argument("-o", "--output", help="Output ICO file") + parser.add_argument("--sizes", nargs="+", type=int, + default=[16, 32, 48, 64, 128, 256], + help="Icon sizes to include") + parser.add_argument("--assets", action="store_true", + help="Convert all PNG files in assets directory") + + args = parser.parse_args() + + if args.assets or not args.input: + print("Converting assets folder...") + converted = convert_assets_folder("assets") + if converted: + print(f"Converted {len(converted)} files:") + for file in converted: + print(f" - {file}") + else: + print("No files converted.") + else: + if Path(args.input).is_file(): + result = png_to_ico(args.input, args.output, args.sizes) + if result: + print(f"Conversion successful: {result}") + else: + print("Conversion failed.") + else: + print(f"Input file not found: {args.input}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dark_agent.py b/dark_agent.py new file mode 100644 index 0000000..80305b1 --- /dev/null +++ b/dark_agent.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +""" +Dark Agent (DarkAgent) Tab Implementation +Handles the Dark Agent interface and controls for DarkHal 2.0 +""" + +import tkinter as tk +from tkinter import ttk, messagebox +from pathlib import Path + +class DarkAgentTab: + """Dark Agent tab with agent control and configuration.""" + + def __init__(self, parent: ttk.Frame, settings_manager, main_app=None): + self.parent = parent + self.settings = settings_manager + self.main_app = main_app + + # Initialize Dark Agent integration + self._hal_integration = None + + # Create main frame + self.main_frame = ttk.Frame(parent) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create Dark Agent interface + self._create_dark_agent_interface() + + def _create_dark_agent_interface(self): + """Create Dark Agent control and configuration interface.""" + + # Dark configuration frame + config_frame = ttk.LabelFrame(self.main_frame, text="Dark Agent Configuration", padding=10) + config_frame.pack(fill=tk.X, pady=(0, 10)) + + # Configuration options + options_frame = ttk.Frame(config_frame) + options_frame.pack(fill=tk.X, pady=10) + + # Agent Name (fixed, not editable) + ttk.Label(options_frame, text="Agent Name:").grid(row=0, column=0, sticky=tk.W, pady=5) + ttk.Label(options_frame, text="Dhal", font=("Arial", 9, "bold")).grid(row=0, column=1, sticky=tk.W, padx=10) + + # Model Configuration + ttk.Label(options_frame, text="Model:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.hal_model_var = tk.StringVar(value="local-llm") + ttk.Entry(options_frame, textvariable=self.hal_model_var, width=20).grid(row=1, column=1, padx=10) + + # System Message + ttk.Label(options_frame, text="System Message:").grid(row=0, column=2, sticky=tk.W, padx=(20, 0)) + self.hal_system_var = tk.StringVar(value="You are Dhal, an advanced AI assistant integrated into DarkHal 2.0.") + ttk.Entry(options_frame, textvariable=self.hal_system_var, width=40).grid(row=0, column=3, padx=10) + + # Temperature + ttk.Label(options_frame, text="Temperature:").grid(row=1, column=2, sticky=tk.W, padx=(20, 0)) + self.hal_temp_var = tk.StringVar(value="0.7") + ttk.Entry(options_frame, textvariable=self.hal_temp_var, width=10).grid(row=1, column=3, padx=10) + + # Max Tokens + ttk.Label(options_frame, text="Max Tokens:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.hal_max_tokens_var = tk.StringVar(value="2048") + ttk.Entry(options_frame, textvariable=self.hal_max_tokens_var, width=10).grid(row=2, column=1, padx=10) + + # Tool Configuration + tools_frame = ttk.LabelFrame(config_frame, text="Available Tools", padding=5) + tools_frame.pack(fill=tk.X, pady=10) + + self.hal_tools = {} + tools = ["Web Search", "Code Execution", "File Operations", "System Commands"] + for i, tool in enumerate(tools): + var = tk.BooleanVar(value=True) + self.hal_tools[tool.lower().replace(" ", "_")] = var + ttk.Checkbutton(tools_frame, text=tool, variable=var).grid(row=i//2, column=i%2, sticky=tk.W, padx=10, pady=2) + + # Control buttons + control_frame = ttk.Frame(config_frame) + control_frame.pack(fill=tk.X, pady=10) + + self.hal_start_btn = ttk.Button(control_frame, text="Start Dhal Agent", + command=self._start_hal) + self.hal_start_btn.pack(side=tk.LEFT, padx=5) + + self.hal_stop_btn = ttk.Button(control_frame, text="Stop Agent", + command=self._stop_hal, state="disabled") + self.hal_stop_btn.pack(side=tk.LEFT, padx=5) + + self.hal_reset_btn = ttk.Button(control_frame, text="Reset Conversation", + command=self._reset_hal, state="disabled") + self.hal_reset_btn.pack(side=tk.LEFT, padx=5) + + # Configuration buttons + config_btn_frame = ttk.Frame(control_frame) + config_btn_frame.pack(side=tk.RIGHT, padx=5) + + ttk.Button(config_btn_frame, text="Save Config", + command=self._save_hal_config).pack(side=tk.LEFT, padx=2) + ttk.Button(config_btn_frame, text="Load Config", + command=self._load_hal_config).pack(side=tk.LEFT, padx=2) + + # Chat interface + chat_frame = ttk.LabelFrame(self.main_frame, text="Dark Agent Chat", padding=10) + chat_frame.pack(fill=tk.BOTH, expand=True, pady=10) + + # Output area + self.hal_output = tk.Text(chat_frame, height=15, wrap=tk.WORD) + hal_scrollbar = ttk.Scrollbar(chat_frame, orient=tk.VERTICAL, command=self.hal_output.yview) + self.hal_output.configure(yscrollcommand=hal_scrollbar.set) + self.hal_output.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + hal_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # Input area + input_frame = ttk.Frame(chat_frame) + input_frame.pack(fill=tk.X, pady=(10, 0)) + + self.hal_input_var = tk.StringVar() + hal_input = ttk.Entry(input_frame, textvariable=self.hal_input_var) + hal_input.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 10)) + hal_input.bind("", lambda e: self._send_hal_message()) + + self.hal_send_btn = ttk.Button(input_frame, text="Send", + command=self._send_hal_message, state="disabled") + self.hal_send_btn.pack(side=tk.RIGHT, padx=5) + + # Status bar + self.hal_status_var = tk.StringVar(value="Dhal Status: Not Started") + ttk.Label(chat_frame, textvariable=self.hal_status_var).pack(anchor=tk.W, pady=(5, 0)) + + # Dark Agent Methods + def _start_hal(self): + """Start Dark agent.""" + if not self._hal_integration: + try: + from agent_dhal_integration import DhalAgentIntegration + self._hal_integration = DhalAgentIntegration(self) + except ImportError: + messagebox.showerror("Import Error", "Could not import DhalAgentIntegration") + return + self._hal_integration.start_agent() + + def _stop_hal(self): + """Stop Dark agent.""" + if self._hal_integration: + self._hal_integration.stop_agent() + + def _send_hal_message(self): + """Send message to Dark agent.""" + if self._hal_integration: + self._hal_integration.send_message() + + def _reset_hal(self): + """Reset Dark agent conversation.""" + if self._hal_integration: + self._hal_integration.reset_conversation() + + def _save_hal_config(self): + """Save Dark agent configuration.""" + if not self._hal_integration: + try: + from agent_dhal_integration import DhalAgentIntegration + self._hal_integration = DhalAgentIntegration(self) + except ImportError: + messagebox.showerror("Import Error", "Could not import DhalAgentIntegration") + return + self._hal_integration.save_config() + + def _load_hal_config(self): + """Load Dark agent configuration.""" + if not self._hal_integration: + try: + from agent_dhal_integration import DhalAgentIntegration + self._hal_integration = DhalAgentIntegration(self) + except ImportError: + messagebox.showerror("Import Error", "Could not import DhalAgentIntegration") + return + self._hal_integration.load_config() \ No newline at end of file diff --git a/download_manager.py b/download_manager.py new file mode 100644 index 0000000..6749aa4 --- /dev/null +++ b/download_manager.py @@ -0,0 +1,582 @@ +import os +import requests +import threading +import time +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Optional, Dict, Any, List, Callable +from dataclasses import dataclass +from enum import Enum +import json +from datetime import datetime +import queue + + +class DownloadStatus(Enum): + """Download status enumeration.""" + QUEUED = "Queued" + DOWNLOADING = "Downloading" + PAUSED = "Paused" + COMPLETED = "Completed" + FAILED = "Failed" + CANCELLED = "Cancelled" + AUTH_REQUIRED = "Auth Required" + + +@dataclass +class DownloadItem: + """Represents a download item.""" + id: str + repo_id: str + filename: str + url: str + save_path: str + total_size: int = 0 + downloaded_size: int = 0 + status: DownloadStatus = DownloadStatus.QUEUED + error_message: str = "" + start_time: Optional[float] = None + end_time: Optional[float] = None + speed: float = 0.0 + eta: int = 0 + headers: Dict[str, str] = None + resume_position: int = 0 + + def __post_init__(self): + if self.headers is None: + self.headers = {} + + @property + def progress(self) -> float: + """Calculate download progress percentage.""" + if self.total_size > 0: + return (self.downloaded_size / self.total_size) * 100 + return 0.0 + + @property + def is_resumable(self) -> bool: + """Check if download can be resumed.""" + return self.status in [DownloadStatus.PAUSED, DownloadStatus.FAILED] + + +class DownloadManager: + """Manages multiple downloads with pause/resume support.""" + + def __init__(self, max_concurrent: int = 3): + self.downloads: Dict[str, DownloadItem] = {} + self.download_queue: queue.Queue = queue.Queue() + self.active_downloads: Dict[str, threading.Thread] = {} + self.max_concurrent = max_concurrent + self.callbacks: Dict[str, List[Callable]] = { + 'on_progress': [], + 'on_status_change': [], + 'on_complete': [], + 'on_error': [] + } + self._stop_flags: Dict[str, threading.Event] = {} + self._pause_flags: Dict[str, threading.Event] = {} + self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) + self._worker_thread.start() + + def add_download(self, repo_id: str, filename: str, url: str, save_path: str, + headers: Optional[Dict[str, str]] = None) -> str: + """Add a new download to the queue.""" + download_id = f"{repo_id}_{filename}_{int(time.time())}" + + # Create download item + item = DownloadItem( + id=download_id, + repo_id=repo_id, + filename=filename, + url=url, + save_path=save_path, + headers=headers or {} + ) + + self.downloads[download_id] = item + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + return download_id + + def pause_download(self, download_id: str): + """Pause a download.""" + if download_id in self._pause_flags: + self._pause_flags[download_id].set() + if download_id in self.downloads: + self.downloads[download_id].status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', self.downloads[download_id]) + + def resume_download(self, download_id: str): + """Resume a paused download.""" + item = self.downloads.get(download_id) + if item and item.is_resumable: + item.status = DownloadStatus.QUEUED + item.resume_position = item.downloaded_size + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + def cancel_download(self, download_id: str): + """Cancel a download.""" + if download_id in self._stop_flags: + self._stop_flags[download_id].set() + + if download_id in self.downloads: + self.downloads[download_id].status = DownloadStatus.CANCELLED + self._trigger_callback('on_status_change', self.downloads[download_id]) + + # Remove partial file + item = self.downloads[download_id] + if os.path.exists(item.save_path): + try: + os.remove(item.save_path) + except Exception: + pass + + def retry_download(self, download_id: str): + """Retry a failed download.""" + item = self.downloads.get(download_id) + if item and item.status in [DownloadStatus.FAILED, DownloadStatus.AUTH_REQUIRED]: + item.status = DownloadStatus.QUEUED + item.downloaded_size = 0 + item.resume_position = 0 + item.error_message = "" + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + def _process_queue(self): + """Process download queue.""" + while True: + # Check if we can start more downloads + if len(self.active_downloads) < self.max_concurrent: + try: + download_id = self.download_queue.get(timeout=1) + if download_id in self.downloads: + thread = threading.Thread( + target=self._download_file, + args=(download_id,), + daemon=True + ) + self.active_downloads[download_id] = thread + thread.start() + except queue.Empty: + pass + + # Clean up finished downloads + finished = [] + for download_id, thread in self.active_downloads.items(): + if not thread.is_alive(): + finished.append(download_id) + + for download_id in finished: + del self.active_downloads[download_id] + + time.sleep(0.5) + + def _download_file(self, download_id: str): + """Download a file with resume support.""" + item = self.downloads[download_id] + + # Create flags for this download + self._stop_flags[download_id] = threading.Event() + self._pause_flags[download_id] = threading.Event() + + try: + # Update status + item.status = DownloadStatus.DOWNLOADING + item.start_time = time.time() + self._trigger_callback('on_status_change', item) + + # Create directory if needed + os.makedirs(os.path.dirname(item.save_path), exist_ok=True) + + # Setup headers for resume + headers = item.headers.copy() + if item.resume_position > 0: + headers['Range'] = f'bytes={item.resume_position}-' + + # Make request + response = requests.get(item.url, headers=headers, stream=True, timeout=30) + + # Check for authentication issues + if response.status_code == 401 or response.status_code == 403: + item.status = DownloadStatus.AUTH_REQUIRED + item.error_message = f"Authentication failed: {response.status_code}" + self._trigger_callback('on_error', item) + return + + response.raise_for_status() + + # Get total size + if item.resume_position == 0: + item.total_size = int(response.headers.get('content-length', 0)) + + # Open file for writing (append if resuming) + mode = 'ab' if item.resume_position > 0 else 'wb' + + with open(item.save_path, mode) as f: + chunk_size = 8192 + last_update = time.time() + bytes_since_update = 0 + + for chunk in response.iter_content(chunk_size=chunk_size): + # Check stop flag + if self._stop_flags[download_id].is_set(): + return + + # Check pause flag + if self._pause_flags[download_id].is_set(): + item.status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', item) + return + + if chunk: + f.write(chunk) + item.downloaded_size += len(chunk) + bytes_since_update += len(chunk) + + # Calculate speed and ETA + current_time = time.time() + time_diff = current_time - last_update + + if time_diff >= 1.0: # Update every second + item.speed = bytes_since_update / time_diff + + if item.speed > 0: + remaining = item.total_size - item.downloaded_size + item.eta = int(remaining / item.speed) + + self._trigger_callback('on_progress', item) + last_update = current_time + bytes_since_update = 0 + + # Download completed + item.status = DownloadStatus.COMPLETED + item.end_time = time.time() + self._trigger_callback('on_complete', item) + + except requests.exceptions.RequestException as e: + item.status = DownloadStatus.FAILED + item.error_message = str(e) + self._trigger_callback('on_error', item) + + except Exception as e: + item.status = DownloadStatus.FAILED + item.error_message = f"Unexpected error: {e}" + self._trigger_callback('on_error', item) + + finally: + # Clean up flags + if download_id in self._stop_flags: + del self._stop_flags[download_id] + if download_id in self._pause_flags: + del self._pause_flags[download_id] + + self._trigger_callback('on_status_change', item) + + def register_callback(self, event: str, callback: Callable): + """Register a callback for download events.""" + if event in self.callbacks: + self.callbacks[event].append(callback) + + def _trigger_callback(self, event: str, item: DownloadItem): + """Trigger callbacks for an event.""" + for callback in self.callbacks.get(event, []): + try: + callback(item) + except Exception as e: + print(f"Callback error: {e}") + + def get_all_downloads(self) -> List[DownloadItem]: + """Get all download items.""" + return list(self.downloads.values()) + + def clear_completed(self): + """Clear completed downloads from the list.""" + to_remove = [] + for download_id, item in self.downloads.items(): + if item.status in [DownloadStatus.COMPLETED, DownloadStatus.CANCELLED]: + to_remove.append(download_id) + + for download_id in to_remove: + del self.downloads[download_id] + + +class DownloadManagerTab: + """Download Manager GUI tab.""" + + def __init__(self, parent: ttk.Frame, download_manager: DownloadManager): + self.parent = parent + self.manager = download_manager + self.item_widgets: Dict[str, Dict[str, Any]] = {} + + # Register callbacks + self.manager.register_callback('on_progress', self._on_progress) + self.manager.register_callback('on_status_change', self._on_status_change) + self.manager.register_callback('on_complete', self._on_complete) + self.manager.register_callback('on_error', self._on_error) + + self._build_ui() + + # Start update timer + self._update_display() + + def _build_ui(self): + """Build the download manager UI.""" + # Top controls + controls_frame = ttk.Frame(self.parent) + controls_frame.pack(fill=tk.X, padx=5, pady=5) + + ttk.Button(controls_frame, text="Clear Completed", + command=self._clear_completed).pack(side=tk.LEFT, padx=5) + ttk.Button(controls_frame, text="Pause All", + command=self._pause_all).pack(side=tk.LEFT, padx=5) + ttk.Button(controls_frame, text="Resume All", + command=self._resume_all).pack(side=tk.LEFT, padx=5) + + # Status summary + self.status_label = ttk.Label(controls_frame, text="Downloads: 0 active, 0 queued") + self.status_label.pack(side=tk.RIGHT, padx=5) + + # Downloads list with scrollbar + list_frame = ttk.Frame(self.parent) + list_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # Canvas for scrollable content + self.canvas = tk.Canvas(list_frame, bg='white') + scrollbar = ttk.Scrollbar(list_frame, orient="vertical", command=self.canvas.yview) + self.scrollable_frame = ttk.Frame(self.canvas) + + self.scrollable_frame.bind( + "", + lambda e: self.canvas.configure(scrollregion=self.canvas.bbox("all")) + ) + + self.canvas.create_window((0, 0), window=self.scrollable_frame, anchor="nw") + self.canvas.configure(yscrollcommand=scrollbar.set) + + self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + def add_download_widget(self, item: DownloadItem): + """Add a download widget to the display.""" + if item.id in self.item_widgets: + return + + # Create frame for this download + frame = ttk.LabelFrame(self.scrollable_frame, text=f"{item.repo_id} / {item.filename}") + frame.pack(fill=tk.X, padx=5, pady=5) + + # Info row + info_frame = ttk.Frame(frame) + info_frame.pack(fill=tk.X, padx=5, pady=5) + + status_label = ttk.Label(info_frame, text=f"Status: {item.status.value}") + status_label.pack(side=tk.LEFT, padx=5) + + size_label = ttk.Label(info_frame, text="Size: -") + size_label.pack(side=tk.LEFT, padx=5) + + speed_label = ttk.Label(info_frame, text="Speed: -") + speed_label.pack(side=tk.LEFT, padx=5) + + eta_label = ttk.Label(info_frame, text="ETA: -") + eta_label.pack(side=tk.LEFT, padx=5) + + # Progress bar + progress_bar = ttk.Progressbar(frame, length=400, mode='determinate') + progress_bar.pack(fill=tk.X, padx=5, pady=5) + + # Error label (hidden by default) + error_label = ttk.Label(frame, text="", foreground="red") + + # Control buttons + controls_frame = ttk.Frame(frame) + controls_frame.pack(fill=tk.X, padx=5, pady=5) + + pause_btn = ttk.Button(controls_frame, text="Pause", width=10, + command=lambda: self.manager.pause_download(item.id)) + pause_btn.pack(side=tk.LEFT, padx=2) + + resume_btn = ttk.Button(controls_frame, text="Resume", width=10, + command=lambda: self.manager.resume_download(item.id)) + resume_btn.pack(side=tk.LEFT, padx=2) + + cancel_btn = ttk.Button(controls_frame, text="Cancel", width=10, + command=lambda: self.manager.cancel_download(item.id)) + cancel_btn.pack(side=tk.LEFT, padx=2) + + retry_btn = ttk.Button(controls_frame, text="Retry", width=10, + command=lambda: self.manager.retry_download(item.id)) + retry_btn.pack(side=tk.LEFT, padx=2) + + # Store widgets + self.item_widgets[item.id] = { + 'frame': frame, + 'status_label': status_label, + 'size_label': size_label, + 'speed_label': speed_label, + 'eta_label': eta_label, + 'progress_bar': progress_bar, + 'error_label': error_label, + 'pause_btn': pause_btn, + 'resume_btn': resume_btn, + 'cancel_btn': cancel_btn, + 'retry_btn': retry_btn + } + + # Initial update + self._update_download_widget(item) + + def _update_download_widget(self, item: DownloadItem): + """Update a download widget.""" + if item.id not in self.item_widgets: + self.add_download_widget(item) + return + + widgets = self.item_widgets[item.id] + + # Update status + widgets['status_label'].config(text=f"Status: {item.status.value}") + + # Update size + if item.total_size > 0: + size_text = f"Size: {self._format_size(item.downloaded_size)} / {self._format_size(item.total_size)}" + else: + size_text = f"Size: {self._format_size(item.downloaded_size)}" + widgets['size_label'].config(text=size_text) + + # Update speed + if item.speed > 0: + widgets['speed_label'].config(text=f"Speed: {self._format_size(item.speed)}/s") + else: + widgets['speed_label'].config(text="Speed: -") + + # Update ETA + if item.eta > 0: + eta_text = self._format_time(item.eta) + widgets['eta_label'].config(text=f"ETA: {eta_text}") + else: + widgets['eta_label'].config(text="ETA: -") + + # Update progress bar + widgets['progress_bar']['value'] = item.progress + + # Update error message + if item.error_message: + widgets['error_label'].config(text=f"Error: {item.error_message}") + widgets['error_label'].pack(fill=tk.X, padx=5, pady=2) + else: + widgets['error_label'].pack_forget() + + # Update button states + if item.status == DownloadStatus.DOWNLOADING: + widgets['pause_btn'].config(state="normal") + widgets['resume_btn'].config(state="disabled") + widgets['cancel_btn'].config(state="normal") + widgets['retry_btn'].config(state="disabled") + elif item.status == DownloadStatus.PAUSED: + widgets['pause_btn'].config(state="disabled") + widgets['resume_btn'].config(state="normal") + widgets['cancel_btn'].config(state="normal") + widgets['retry_btn'].config(state="disabled") + elif item.status in [DownloadStatus.FAILED, DownloadStatus.AUTH_REQUIRED]: + widgets['pause_btn'].config(state="disabled") + widgets['resume_btn'].config(state="disabled") + widgets['cancel_btn'].config(state="disabled") + widgets['retry_btn'].config(state="normal") + elif item.status == DownloadStatus.COMPLETED: + widgets['pause_btn'].config(state="disabled") + widgets['resume_btn'].config(state="disabled") + widgets['cancel_btn'].config(state="disabled") + widgets['retry_btn'].config(state="disabled") + else: + widgets['pause_btn'].config(state="disabled") + widgets['resume_btn'].config(state="disabled") + widgets['cancel_btn'].config(state="normal") + widgets['retry_btn'].config(state="disabled") + + def _format_size(self, bytes_size: float) -> str: + """Format bytes to human readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + + def _format_time(self, seconds: int) -> str: + """Format seconds to human readable time.""" + if seconds < 60: + return f"{seconds}s" + elif seconds < 3600: + minutes = seconds // 60 + secs = seconds % 60 + return f"{minutes}m {secs}s" + else: + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + return f"{hours}h {minutes}m" + + def _on_progress(self, item: DownloadItem): + """Handle progress update.""" + self.parent.after(0, lambda: self._update_download_widget(item)) + + def _on_status_change(self, item: DownloadItem): + """Handle status change.""" + self.parent.after(0, lambda: self._update_download_widget(item)) + + def _on_complete(self, item: DownloadItem): + """Handle download completion.""" + self.parent.after(0, lambda: self._update_download_widget(item)) + + def _on_error(self, item: DownloadItem): + """Handle download error.""" + self.parent.after(0, lambda: self._update_download_widget(item)) + + # Show error notification for auth issues + if item.status == DownloadStatus.AUTH_REQUIRED: + self.parent.after(0, lambda: messagebox.showerror( + "Authentication Required", + f"Authentication failed for {item.filename}.\n" + f"Please check your API key in Settings." + )) + + def _clear_completed(self): + """Clear completed downloads.""" + # Remove widgets for completed downloads + for download_id in list(self.item_widgets.keys()): + item = self.manager.downloads.get(download_id) + if item and item.status in [DownloadStatus.COMPLETED, DownloadStatus.CANCELLED]: + self.item_widgets[download_id]['frame'].destroy() + del self.item_widgets[download_id] + + # Clear from manager + self.manager.clear_completed() + + def _pause_all(self): + """Pause all active downloads.""" + for item in self.manager.get_all_downloads(): + if item.status == DownloadStatus.DOWNLOADING: + self.manager.pause_download(item.id) + + def _resume_all(self): + """Resume all paused downloads.""" + for item in self.manager.get_all_downloads(): + if item.status == DownloadStatus.PAUSED: + self.manager.resume_download(item.id) + + def _update_display(self): + """Update the display periodically.""" + # Update status summary + all_downloads = self.manager.get_all_downloads() + active = sum(1 for d in all_downloads if d.status == DownloadStatus.DOWNLOADING) + queued = sum(1 for d in all_downloads if d.status == DownloadStatus.QUEUED) + self.status_label.config(text=f"Downloads: {active} active, {queued} queued") + + # Add any new downloads + for item in all_downloads: + if item.id not in self.item_widgets: + self.add_download_item(item) + + # Schedule next update + self.parent.after(500, self._update_display) \ No newline at end of file diff --git a/download_manager_fixed.py b/download_manager_fixed.py new file mode 100644 index 0000000..79d3542 --- /dev/null +++ b/download_manager_fixed.py @@ -0,0 +1,734 @@ +import os +import requests +import threading +import time +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Optional, Dict, Any, List, Callable +from dataclasses import dataclass +from enum import Enum +import json +from datetime import datetime +import queue + + +class DownloadStatus(Enum): + """Download status enumeration.""" + QUEUED = "Queued" + DOWNLOADING = "Downloading" + PAUSED = "Paused" + COMPLETED = "Completed" + FAILED = "Failed" + CANCELLED = "Cancelled" + AUTH_REQUIRED = "Auth Required" + + +@dataclass +class DownloadItem: + """Represents a download item.""" + id: str + repo_id: str + filename: str + url: str + save_path: str + total_size: int = 0 + downloaded_size: int = 0 + status: DownloadStatus = DownloadStatus.QUEUED + error_message: str = "" + start_time: Optional[float] = None + end_time: Optional[float] = None + speed: float = 0.0 + eta: int = 0 + headers: Dict[str, str] = None + resume_position: int = 0 + + def __post_init__(self): + if self.headers is None: + self.headers = {} + + @property + def progress(self) -> float: + """Calculate download progress percentage.""" + if self.total_size > 0: + return (self.downloaded_size / self.total_size) * 100 + return 0.0 + + @property + def is_resumable(self) -> bool: + """Check if download can be resumed.""" + return self.status in [DownloadStatus.PAUSED, DownloadStatus.FAILED] + + +class DownloadManager: + """Manages multiple downloads with pause/resume support.""" + + def __init__(self, max_concurrent: int = 3): + self.downloads: Dict[str, DownloadItem] = {} + self.download_queue: queue.Queue = queue.Queue() + self.active_downloads: Dict[str, threading.Thread] = {} + self.max_concurrent = max_concurrent + self.callbacks: Dict[str, List[Callable]] = { + 'on_progress': [], + 'on_status_change': [], + 'on_complete': [], + 'on_error': [], + 'on_remove': [] + } + self._stop_flags: Dict[str, threading.Event] = {} + self._pause_flags: Dict[str, threading.Event] = {} + self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) + self._worker_thread.start() + + def add_download(self, repo_id: str, filename: str, url: str, save_path: str, + headers: Optional[Dict[str, str]] = None) -> str: + """Add a new download to the queue.""" + download_id = f"{repo_id}_{filename}_{int(time.time())}" + + # Create download item + item = DownloadItem( + id=download_id, + repo_id=repo_id, + filename=filename, + url=url, + save_path=save_path, + headers=headers or {} + ) + + self.downloads[download_id] = item + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + return download_id + + def pause_download(self, download_id: str): + """Pause a download.""" + if download_id in self._pause_flags: + self._pause_flags[download_id].set() + if download_id in self.downloads: + self.downloads[download_id].status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', self.downloads[download_id]) + + def resume_download(self, download_id: str): + """Resume a paused download.""" + item = self.downloads.get(download_id) + if item and item.is_resumable: + item.status = DownloadStatus.QUEUED + item.resume_position = item.downloaded_size + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + def cancel_download(self, download_id: str): + """Cancel a download.""" + if download_id in self._stop_flags: + self._stop_flags[download_id].set() + + if download_id in self.downloads: + item = self.downloads[download_id] + item.status = DownloadStatus.CANCELLED + self._trigger_callback('on_status_change', item) + + # Remove partial file + if os.path.exists(item.save_path): + try: + os.remove(item.save_path) + except Exception: + pass + + # Auto-remove cancelled download after a short delay + def auto_remove(): + time.sleep(2) # Wait 2 seconds + self.remove_download(download_id) + + threading.Thread(target=auto_remove, daemon=True).start() + + def remove_download(self, download_id: str): + """Remove a specific download from the list.""" + if download_id in self.downloads: + # Cancel if still active + if download_id in self._stop_flags: + self._stop_flags[download_id].set() + + # Remove from downloads + item = self.downloads[download_id] + del self.downloads[download_id] + self._trigger_callback('on_remove', item) + + def retry_download(self, download_id: str): + """Retry a failed download.""" + item = self.downloads.get(download_id) + if item and item.status in [DownloadStatus.FAILED, DownloadStatus.AUTH_REQUIRED]: + item.status = DownloadStatus.QUEUED + item.downloaded_size = 0 + item.resume_position = 0 + item.error_message = "" + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + + def _process_queue(self): + """Process download queue.""" + while True: + # Check if we can start more downloads + if len(self.active_downloads) < self.max_concurrent: + try: + download_id = self.download_queue.get(timeout=1) + if download_id in self.downloads: + thread = threading.Thread( + target=self._download_file, + args=(download_id,), + daemon=True + ) + self.active_downloads[download_id] = thread + thread.start() + except queue.Empty: + pass + + # Clean up finished downloads + finished = [] + for download_id, thread in self.active_downloads.items(): + if not thread.is_alive(): + finished.append(download_id) + + for download_id in finished: + del self.active_downloads[download_id] + + time.sleep(0.5) + + def _download_file(self, download_id: str): + """Download a file with resume support.""" + item = self.downloads[download_id] + + # Create flags for this download + self._stop_flags[download_id] = threading.Event() + self._pause_flags[download_id] = threading.Event() + + try: + # Update status + item.status = DownloadStatus.DOWNLOADING + item.start_time = time.time() + self._trigger_callback('on_status_change', item) + + # Create directory if needed + os.makedirs(os.path.dirname(item.save_path), exist_ok=True) + + # Setup headers for resume + headers = item.headers.copy() + if item.resume_position > 0: + headers['Range'] = f'bytes={item.resume_position}-' + + # Make request + response = requests.get(item.url, headers=headers, stream=True, timeout=30) + + # Check for authentication issues + if response.status_code == 401 or response.status_code == 403: + item.status = DownloadStatus.AUTH_REQUIRED + item.error_message = f"Authentication failed: {response.status_code}" + self._trigger_callback('on_error', item) + return + + response.raise_for_status() + + # Get total size + if item.resume_position == 0: + item.total_size = int(response.headers.get('content-length', 0)) + + # Open file for writing (append if resuming) + mode = 'ab' if item.resume_position > 0 else 'wb' + + # Optimize chunk size based on file size and storage type + base_chunk_size = 1024 * 1024 # 1MB base chunk + if item.total_size > 100 * 1024 * 1024: # Files > 100MB + chunk_size = base_chunk_size * 4 # 4MB chunks + elif item.total_size > 10 * 1024 * 1024: # Files > 10MB + chunk_size = base_chunk_size * 2 # 2MB chunks + else: + chunk_size = base_chunk_size # 1MB chunks + + # Use buffered writing for better performance + buffer_size = chunk_size * 8 # 8x chunk size buffer + + with open(item.save_path, mode, buffering=buffer_size) as f: + last_update = time.time() + bytes_since_update = 0 + update_interval = 0.5 # Update UI every 0.5 seconds + + for chunk in response.iter_content(chunk_size=chunk_size): + # Check stop flag + if self._stop_flags[download_id].is_set(): + return + + # Check pause flag + if self._pause_flags[download_id].is_set(): + item.status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', item) + return + + if chunk: + f.write(chunk) + item.downloaded_size += len(chunk) + bytes_since_update += len(chunk) + + # Calculate speed and ETA (less frequent updates for performance) + current_time = time.time() + time_diff = current_time - last_update + + if time_diff >= update_interval: + item.speed = bytes_since_update / time_diff + + if item.speed > 0 and item.total_size > item.downloaded_size: + remaining = item.total_size - item.downloaded_size + item.eta = int(remaining / item.speed) + + self._trigger_callback('on_progress', item) + last_update = current_time + bytes_since_update = 0 + + # Force flush for USB drives + f.flush() + os.fsync(f.fileno()) + + # Download completed + item.status = DownloadStatus.COMPLETED + item.end_time = time.time() + self._trigger_callback('on_complete', item) + + except requests.exceptions.RequestException as e: + item.status = DownloadStatus.FAILED + item.error_message = str(e) + self._trigger_callback('on_error', item) + + except Exception as e: + item.status = DownloadStatus.FAILED + item.error_message = f"Unexpected error: {e}" + self._trigger_callback('on_error', item) + + finally: + # Clean up flags + if download_id in self._stop_flags: + del self._stop_flags[download_id] + if download_id in self._pause_flags: + del self._pause_flags[download_id] + + self._trigger_callback('on_status_change', item) + + def register_callback(self, event: str, callback: Callable): + """Register a callback for download events.""" + if event in self.callbacks: + self.callbacks[event].append(callback) + + def _trigger_callback(self, event: str, item: DownloadItem): + """Trigger callbacks for an event.""" + for callback in self.callbacks.get(event, []): + try: + callback(item) + except Exception as e: + print(f"Callback error: {e}") + + def get_all_downloads(self) -> List[DownloadItem]: + """Get all download items.""" + return list(self.downloads.values()) + + def clear_completed(self): + """Clear completed downloads from the list.""" + to_remove = [] + for download_id, item in self.downloads.items(): + if item.status in [DownloadStatus.COMPLETED, DownloadStatus.CANCELLED]: + to_remove.append(download_id) + + for download_id in to_remove: + self.remove_download(download_id) + + +class DownloadManagerTab: + """Download Manager GUI tab with improved real-time updates.""" + + def __init__(self, parent: ttk.Frame, download_manager: DownloadManager): + self.parent = parent + self.manager = download_manager + self.item_widgets: Dict[str, Dict[str, Any]] = {} + + # Register callbacks + self.manager.register_callback('on_progress', self._on_progress) + self.manager.register_callback('on_status_change', self._on_status_change) + self.manager.register_callback('on_complete', self._on_complete) + self.manager.register_callback('on_error', self._on_error) + self.manager.register_callback('on_remove', self._on_remove) + + self._build_ui() + + # Start update timer with shorter interval for real-time updates + self._update_display() + + def _build_ui(self): + """Build the download manager UI.""" + # Top controls with better layout + controls_frame = ttk.Frame(self.parent) + controls_frame.pack(fill=tk.X, padx=10, pady=5) + + # Left side controls + left_controls = ttk.Frame(controls_frame) + left_controls.pack(side=tk.LEFT) + + ttk.Button(left_controls, text="Clear Completed", + command=self._clear_completed).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Pause All", + command=self._pause_all).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Resume All", + command=self._resume_all).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Refresh", + command=self._refresh_list).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Remove Selected", + command=self._remove_selected).pack(side=tk.LEFT, padx=2) + + # Right side status + right_controls = ttk.Frame(controls_frame) + right_controls.pack(side=tk.RIGHT) + + self.status_label = ttk.Label(right_controls, text="Downloads: 0 active, 0 queued") + self.status_label.pack(side=tk.RIGHT) + + # Main downloads area using Treeview for better layout + main_frame = ttk.Frame(self.parent) + main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) + + # Create treeview for downloads + columns = ("repo", "file", "status", "progress", "size", "speed", "eta") + self.tree = ttk.Treeview(main_frame, columns=columns, show="headings", height=15) + + # Configure columns + self.tree.heading("repo", text="Repository") + self.tree.heading("file", text="File") + self.tree.heading("status", text="Status") + self.tree.heading("progress", text="Progress") + self.tree.heading("size", text="Size") + self.tree.heading("speed", text="Speed") + self.tree.heading("eta", text="ETA") + + self.tree.column("repo", width=200) + self.tree.column("file", width=250) + self.tree.column("status", width=100) + self.tree.column("progress", width=100) + self.tree.column("size", width=100) + self.tree.column("speed", width=100) + self.tree.column("eta", width=80) + + # Scrollbars for treeview + v_scrollbar = ttk.Scrollbar(main_frame, orient="vertical", command=self.tree.yview) + h_scrollbar = ttk.Scrollbar(main_frame, orient="horizontal", command=self.tree.xview) + self.tree.configure(yscrollcommand=v_scrollbar.set, xscrollcommand=h_scrollbar.set) + + # Grid layout + self.tree.grid(row=0, column=0, sticky="nsew") + v_scrollbar.grid(row=0, column=1, sticky="ns") + h_scrollbar.grid(row=1, column=0, sticky="ew") + + main_frame.grid_rowconfigure(0, weight=1) + main_frame.grid_columnconfigure(0, weight=1) + + # Context menu for downloads + self.context_menu = tk.Menu(self.tree, tearoff=0) + self.context_menu.add_command(label="Pause", command=self._context_pause) + self.context_menu.add_command(label="Resume", command=self._context_resume) + self.context_menu.add_command(label="Cancel", command=self._context_cancel) + self.context_menu.add_command(label="Retry", command=self._context_retry) + self.context_menu.add_separator() + self.context_menu.add_command(label="Remove", command=self._context_remove) + self.context_menu.add_command(label="Open Folder", command=self._context_open_folder) + + self.tree.bind("", self._show_context_menu) + + # Details frame for selected download + details_frame = ttk.LabelFrame(self.parent, text="Download Details", padding=5) + details_frame.pack(fill=tk.X, padx=10, pady=5) + + self.details_text = tk.Text(details_frame, height=3, wrap=tk.WORD) + self.details_text.pack(fill=tk.X) + + self.tree.bind("<>", self._on_selection_change) + + def add_download_item(self, item: DownloadItem): + """Add a download item to the treeview.""" + if item.id in self.item_widgets: + return + + # Insert into treeview + tree_id = self.tree.insert("", tk.END, values=( + item.repo_id, + item.filename, + item.status.value, + f"{item.progress:.1f}%", + self._format_size(item.total_size) if item.total_size > 0 else "-", + "-", + "-" + )) + + # Store mapping + self.item_widgets[item.id] = { + 'tree_id': tree_id, + 'item': item + } + + # Initial update + self._update_download_item(item) + + def _update_download_item(self, item: DownloadItem): + """Update a download item in the treeview.""" + if item.id not in self.item_widgets: + self.add_download_item(item) + return + + tree_id = self.item_widgets[item.id]['tree_id'] + + # Check if tree item still exists + if not self.tree.exists(tree_id): + # Tree item was deleted, remove from our tracking + del self.item_widgets[item.id] + return + + # Update size display + if item.total_size > 0: + size_text = f"{self._format_size(item.downloaded_size)} / {self._format_size(item.total_size)}" + else: + size_text = self._format_size(item.downloaded_size) + + # Update speed display + speed_text = f"{self._format_size(item.speed)}/s" if item.speed > 0 else "-" + + # Update ETA display + eta_text = self._format_time(item.eta) if item.eta > 0 else "-" + + try: + # Update treeview row + self.tree.item(tree_id, values=( + item.repo_id, + item.filename, + item.status.value, + f"{item.progress:.1f}%", + size_text, + speed_text, + eta_text + )) + + # Update row color based on status + if item.status == DownloadStatus.COMPLETED: + self.tree.item(tree_id, tags=("completed",)) + elif item.status == DownloadStatus.FAILED: + self.tree.item(tree_id, tags=("failed",)) + elif item.status == DownloadStatus.DOWNLOADING: + self.tree.item(tree_id, tags=("downloading",)) + else: + self.tree.item(tree_id, tags=()) + + # Configure tag colors + self.tree.tag_configure("completed", background="#d4edda") + self.tree.tag_configure("failed", background="#f8d7da") + self.tree.tag_configure("downloading", background="#d1ecf1") + + except tk.TclError: + # Tree item no longer exists + if item.id in self.item_widgets: + del self.item_widgets[item.id] + + def _format_size(self, bytes_size: float) -> str: + """Format bytes to human readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + + def _format_time(self, seconds: int) -> str: + """Format seconds to human readable time.""" + if seconds < 60: + return f"{seconds}s" + elif seconds < 3600: + minutes = seconds // 60 + secs = seconds % 60 + return f"{minutes}m {secs}s" + else: + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + return f"{hours}h {minutes}m" + + def _show_context_menu(self, event): + """Show context menu for downloads.""" + item = self.tree.selection()[0] if self.tree.selection() else None + if item: + self.context_menu.post(event.x_root, event.y_root) + + def _get_selected_download_id(self): + """Get the download ID of the selected item.""" + selection = self.tree.selection() + if not selection: + return None + + tree_id = selection[0] + for download_id, data in self.item_widgets.items(): + if data['tree_id'] == tree_id: + return download_id + return None + + def _context_pause(self): + download_id = self._get_selected_download_id() + if download_id: + self.manager.pause_download(download_id) + + def _context_resume(self): + download_id = self._get_selected_download_id() + if download_id: + self.manager.resume_download(download_id) + + def _context_cancel(self): + download_id = self._get_selected_download_id() + if download_id: + self.manager.cancel_download(download_id) + + def _context_retry(self): + download_id = self._get_selected_download_id() + if download_id: + self.manager.retry_download(download_id) + + def _context_remove(self): + download_id = self._get_selected_download_id() + if download_id: + self.manager.remove_download(download_id) + + def _context_open_folder(self): + download_id = self._get_selected_download_id() + if download_id and download_id in self.manager.downloads: + item = self.manager.downloads[download_id] + folder = os.path.dirname(item.save_path) + if os.path.exists(folder): + import subprocess + import platform + if platform.system() == "Windows": + subprocess.run(["explorer", folder]) + elif platform.system() == "Darwin": + subprocess.run(["open", folder]) + else: + subprocess.run(["xdg-open", folder]) + + def _refresh_list(self): + """Refresh the download list display.""" + # Clear all items + for item_id in self.tree.get_children(): + self.tree.delete(item_id) + + # Clear widget mapping + self.item_widgets.clear() + + # Re-add all downloads + for item in self.manager.get_all_downloads(): + self.add_download_item(item) + + def _remove_selected(self): + """Remove the selected download.""" + download_id = self._get_selected_download_id() + if download_id: + self.manager.remove_download(download_id) + + def _on_selection_change(self, event): + """Handle selection change in treeview.""" + download_id = self._get_selected_download_id() + if download_id and download_id in self.manager.downloads: + item = self.manager.downloads[download_id] + details = f"File: {item.filename}\\n" + details += f"Repository: {item.repo_id}\\n" + details += f"Save Path: {item.save_path}" + + if item.error_message: + details += f"\\nError: {item.error_message}" + + self.details_text.delete(1.0, tk.END) + self.details_text.insert(1.0, details) + + def _on_progress(self, item: DownloadItem): + """Handle progress update.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_download_item(item)) + + def _on_status_change(self, item: DownloadItem): + """Handle status change.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_download_item(item)) + + def _on_complete(self, item: DownloadItem): + """Handle download completion.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_download_item(item)) + + def _on_error(self, item: DownloadItem): + """Handle download error.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_download_item(item)) + + # Show error notification for auth issues + if item.status == DownloadStatus.AUTH_REQUIRED: + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: messagebox.showerror( + "Authentication Required", + f"Authentication failed for {item.filename}.\\n" + f"Please check your API key in Settings." + )) + + def _on_remove(self, item: DownloadItem): + """Handle download removal.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._remove_download_from_tree(item.id)) + + def _remove_download_from_tree(self, download_id: str): + """Remove download from treeview.""" + if download_id in self.item_widgets: + try: + tree_id = self.item_widgets[download_id]['tree_id'] + self.tree.delete(tree_id) + del self.item_widgets[download_id] + except (tk.TclError, KeyError): + # Item already removed or doesn't exist + if download_id in self.item_widgets: + del self.item_widgets[download_id] + + def _clear_completed(self): + """Clear completed downloads.""" + self.manager.clear_completed() + + def _pause_all(self): + """Pause all active downloads.""" + for item in self.manager.get_all_downloads(): + if item.status == DownloadStatus.DOWNLOADING: + self.manager.pause_download(item.id) + + def _resume_all(self): + """Resume all paused downloads.""" + for item in self.manager.get_all_downloads(): + if item.status == DownloadStatus.PAUSED: + self.manager.resume_download(item.id) + + def _update_display(self): + """Update the display periodically with real-time updates.""" + try: + # Update status summary + all_downloads = self.manager.get_all_downloads() + active = sum(1 for d in all_downloads if d.status == DownloadStatus.DOWNLOADING) + queued = sum(1 for d in all_downloads if d.status == DownloadStatus.QUEUED) + completed = sum(1 for d in all_downloads if d.status == DownloadStatus.COMPLETED) + failed = sum(1 for d in all_downloads if d.status in [DownloadStatus.FAILED, DownloadStatus.CANCELLED]) + + status_text = f"Downloads: {active} active, {queued} queued, {completed} completed, {failed} failed" + self.status_label.config(text=status_text) + + # Add any new downloads + for item in all_downloads: + if item.id not in self.item_widgets: + self.add_download_item(item) + + # Force update all items for real-time progress + for item in all_downloads: + if item.id in self.item_widgets: + self._update_download_item(item) + + except Exception as e: + print(f"Error updating download display: {e}") + finally: + # Schedule next update with shorter interval for real-time feel + if hasattr(self.parent, 'after'): + self.parent.after(200, self._update_display) # Update every 200ms for smooth progress \ No newline at end of file diff --git a/engines/llm_uci.py b/engines/llm_uci.py new file mode 100644 index 0000000..a1df62a --- /dev/null +++ b/engines/llm_uci.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Minimal UCI chess engine that delegates move selection to an LLM (optional). +Works with GUIs like Arena/Cute Chess. Acts as "player 2" (black) when the GUI pairs it that way. + +This is integrated with DarkHal 2.0's LLM system. +""" + +import os +import sys +import time +import json +import random +import requests +import threading +from pathlib import Path + +try: + import chess +except ImportError: + print("Please install python-chess: pip install python-chess==1.999") + sys.exit(1) + +ENGINE_NAME = "DarkHal-Chess-Engine" +ENGINE_AUTHOR = "Setec Labs" + +class DarkHalChessEngine: + """UCI Chess Engine integrated with DarkHal 2.0""" + + def __init__(self): + self.board = chess.Board() + self.history_san = [] + self.settings_file = Path("../settings.json") + self.llm_settings = self._load_llm_settings() + + def _load_llm_settings(self): + """Load LLM settings from DarkHal 2.0 settings file.""" + try: + if self.settings_file.exists(): + with open(self.settings_file, 'r') as f: + settings = json.load(f) + return { + 'model_path': settings.get('paths', {}).get('last_model_path', ''), + 'temperature': 0.3, + 'max_tokens': 50, + 'n_ctx': settings.get('model_settings', {}).get('default_n_ctx', 4096), + 'n_gpu_layers': settings.get('model_settings', {}).get('default_n_gpu_layers', 0) + } + except Exception: + pass + + return { + 'model_path': '', + 'temperature': 0.3, + 'max_tokens': 50, + 'n_ctx': 4096, + 'n_gpu_layers': 0 + } + + def _query_darkhal_llm(self, prompt: str) -> str: + """Query DarkHal's loaded LLM model for a chess move.""" + try: + # Try to import llama_cpp from the main project + sys.path.append('..') + from llama_cpp import Llama + + # Check if we have a model path + model_path = self.llm_settings.get('model_path', '') + if not model_path or not os.path.exists(model_path): + return None + + # Create Llama instance (simplified - in real implementation this would be cached) + llm = Llama( + model_path=model_path, + n_ctx=self.llm_settings['n_ctx'], + n_gpu_layers=self.llm_settings['n_gpu_layers'], + verbose=False + ) + + # Generate response + response = llm( + prompt, + max_tokens=self.llm_settings['max_tokens'], + temperature=self.llm_settings['temperature'], + stop=["\n", ".", ",", " "], + echo=False + ) + + return response['choices'][0]['text'].strip() + + except Exception as e: + # If LLM fails, return None to fall back to random moves + return None + + def _query_ollama(self, prompt: str) -> str: + """Query an Ollama server for a chess move.""" + model = os.getenv("OLLAMA_MODEL") + host = os.getenv("OLLAMA_HOST", "http://localhost:11434") + if not model: + return None + try: + r = requests.post( + f"{host}/api/generate", + json={"model": model, "prompt": prompt, "stream": False, "options": {"temperature": 0.3}}, + timeout=30, + ) + r.raise_for_status() + data = r.json() + return data.get("response", "") + except Exception: + return None + + def _query_llamacpp_server(self, prompt: str) -> str: + """Query a llama.cpp-compatible server.""" + url = os.getenv("LLAMACPP_URL") # e.g., "http://localhost:8080" + if not url: + return None + try: + r = requests.post( + f"{url}/completion", + json={"prompt": prompt, "n_predict": 64, "temperature": 0.3, "stop": ["\n"]}, + timeout=30, + ) + r.raise_for_status() + data = r.json() + return (data.get("content") or data.get("result") or "").strip() + except Exception: + return None + + def _ask_llm_for_move(self, fen: str, legal_uci: list, history_san: list) -> str: + """Ask the LLM to choose a move from legal_uci.""" + legal_str = ", ".join(legal_uci[:20]) # Limit to first 20 moves to avoid token limit + hist_str = " ".join(history_san[-10:]) if history_san else "game start" + + # Create a chess-focused prompt + prompt = f"""You are a chess engine playing as {'Black' if self.board.turn == chess.BLACK else 'White'}. + +Current position (FEN): {fen} +Recent moves: {hist_str} +Legal moves available: {legal_str} + +Choose the BEST move from the legal moves list. Consider: +- Piece safety and development +- Control of center squares +- King safety +- Tactical opportunities + +Respond with ONLY the move in UCI format (e.g., "e2e4" or "g1f3"):""" + + # Try different LLM sources + response = None + + # 1. Try DarkHal's internal LLM first + response = self._query_darkhal_llm(prompt) + + # 2. Fall back to Ollama + if not response: + response = self._query_ollama(prompt) + + # 3. Fall back to llama.cpp server + if not response: + response = self._query_llamacpp_server(prompt) + + if not response: + return None + + # Parse the response to extract UCI move + response = response.strip().lower() + + # Look for exact match first + for move in legal_uci: + if move in response: + return move + + # Try to extract move-like patterns + import re + move_pattern = r'[a-h][1-8][a-h][1-8][qrnb]?' + matches = re.findall(move_pattern, response) + + for match in matches: + if match in legal_uci: + return match + + return None + + def loop(self): + """Main UCI communication loop.""" + while True: + try: + line = sys.stdin.readline() + if not line: + break + line = line.strip() + + if line == "uci": + self.cmd_uci() + elif line == "isready": + self.cmd_isready() + elif line == "ucinewgame": + self.cmd_ucinewgame() + elif line.startswith("position"): + self.cmd_position(line) + elif line.startswith("go"): + self.cmd_go(line) + elif line == "quit": + break + # Ignore other commands: "stop", "ponderhit", "setoption", etc. + + except (EOFError, KeyboardInterrupt): + break + + def cmd_uci(self): + """Handle UCI identification command.""" + print(f"id name {ENGINE_NAME}") + print(f"id author {ENGINE_AUTHOR}") + # Engine options + print("option name Skill Level type spin default 5 min 0 max 10") + print("option name Use LLM type check default true") + print("uciok") + sys.stdout.flush() + + def cmd_isready(self): + """Handle UCI ready check.""" + print("readyok") + sys.stdout.flush() + + def cmd_ucinewgame(self): + """Handle new game command.""" + self.board = chess.Board() + self.history_san.clear() + + def cmd_position(self, line: str): + """Handle position setup command.""" + parts = line.split() + + if "startpos" in parts: + self.board = chess.Board() + moves_index = parts.index("startpos") + 1 + elif "fen" in parts: + fen_index = parts.index("fen") + 1 + fen = " ".join(parts[fen_index:fen_index + 6]) + self.board = chess.Board(fen) + moves_index = fen_index + 6 + else: + return + + # Apply moves if present + if moves_index < len(parts) and parts[moves_index] == "moves": + self.history_san.clear() + for mv in parts[moves_index + 1:]: + try: + move = self.board.parse_uci(mv) + self.history_san.append(self.board.san(move)) + self.board.push(move) + except Exception: + # Ignore illegal moves from GUI (shouldn't happen) + pass + + def cmd_go(self, line: str): + """Handle go (search for best move) command.""" + # Get legal moves + legal_moves = list(self.board.legal_moves) + legal_uci = [move.uci() for move in legal_moves] + + if not legal_uci: + print("bestmove 0000") + sys.stdout.flush() + return + + # Ask LLM for move + fen = self.board.fen() + chosen_move = self._ask_llm_for_move(fen, legal_uci, self.history_san) + + # Fall back to strategic random if LLM fails + if chosen_move not in legal_uci: + chosen_move = self._choose_fallback_move(legal_moves) + + # Validate and make move + try: + move = self.board.parse_uci(chosen_move) + if move in legal_moves: + # Optional: Print thinking info + print(f"info depth 1 score cp 0 pv {chosen_move}") + print(f"bestmove {chosen_move}") + else: + # Safety fallback + print(f"bestmove {random.choice(legal_uci)}") + except Exception: + print(f"bestmove {random.choice(legal_uci)}") + + sys.stdout.flush() + + def _choose_fallback_move(self, legal_moves): + """Choose a strategic move when LLM fails.""" + # Simple heuristics for move selection + scored_moves = [] + + for move in legal_moves: + score = 0 + + # Prefer captures + if self.board.is_capture(move): + score += 10 + + # Prefer central squares + to_square = move.to_square + file = chess.square_file(to_square) + rank = chess.square_rank(to_square) + if 2 <= file <= 5 and 2 <= rank <= 5: # Central squares + score += 3 + + # Prefer piece development + piece = self.board.piece_at(move.from_square) + if piece and piece.piece_type in [chess.KNIGHT, chess.BISHOP]: + if chess.square_rank(move.from_square) in [0, 7]: # From back rank + score += 5 + + # Avoid moving king early + if piece and piece.piece_type == chess.KING: + score -= 5 + + scored_moves.append((move, score)) + + # Sort by score and add some randomness + scored_moves.sort(key=lambda x: x[1] + random.random(), reverse=True) + return scored_moves[0][0].uci() + + +def main(): + """Main entry point.""" + # Ensure unbuffered output + engine = DarkHalChessEngine() + engine.loop() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/finetune_tab.py b/finetune_tab.py new file mode 100644 index 0000000..0daadff --- /dev/null +++ b/finetune_tab.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Fine Tune Tab for DarkHal 2.0 + +Model fine-tuning interface for training and customizing AI models. +""" + +import tkinter as tk +from tkinter import ttk, messagebox, scrolledtext, filedialog +import os +import sys +import json +from pathlib import Path +from datetime import datetime +from typing import Optional, Dict, Any, List + + +class FineTuneTab: + """Fine-tuning interface for training and customizing AI models.""" + + def __init__(self, parent: ttk.Frame, settings_manager): + self.parent = parent + self.settings = settings_manager + self.current_model = None + self.training_in_progress = False + + # Create main frame + self.main_frame = ttk.Frame(parent) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create fine-tuning interface + self._create_finetune_interface() + + def _create_finetune_interface(self): + """Create the main fine-tuning interface.""" + + # Model Selection Frame + model_frame = ttk.LabelFrame(self.main_frame, text="Model Selection", padding=10) + model_frame.pack(fill=tk.X, pady=(0, 10)) + + # Base model selection + ttk.Label(model_frame, text="Base Model:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.base_model_var = tk.StringVar() + self.base_model_entry = ttk.Entry(model_frame, textvariable=self.base_model_var, width=50) + self.base_model_entry.grid(row=0, column=1, padx=5) + ttk.Button(model_frame, text="Browse", command=self._browse_base_model).grid(row=0, column=2, padx=5) + + # Output model name + ttk.Label(model_frame, text="Output Model Name:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.output_model_var = tk.StringVar(value="my-finetuned-model") + ttk.Entry(model_frame, textvariable=self.output_model_var, width=50).grid(row=1, column=1, padx=5) + + # Training Data Frame + data_frame = ttk.LabelFrame(self.main_frame, text="Training Data", padding=10) + data_frame.pack(fill=tk.X, pady=(0, 10)) + + # Dataset file + ttk.Label(data_frame, text="Dataset File:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.dataset_var = tk.StringVar() + ttk.Entry(data_frame, textvariable=self.dataset_var, width=50).grid(row=0, column=1, padx=5) + ttk.Button(data_frame, text="Browse", command=self._browse_dataset).grid(row=0, column=2, padx=5) + + # Dataset format + ttk.Label(data_frame, text="Dataset Format:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.format_var = tk.StringVar(value="alpaca") + format_combo = ttk.Combobox(data_frame, textvariable=self.format_var, + values=["alpaca", "sharegpt", "completion", "chat", "custom"], + state="readonly", width=20) + format_combo.grid(row=1, column=1, sticky=tk.W, padx=5) + + # Training split + ttk.Label(data_frame, text="Train/Val Split:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.split_var = tk.StringVar(value="90/10") + ttk.Combobox(data_frame, textvariable=self.split_var, + values=["80/20", "90/10", "95/5", "100/0"], + state="readonly", width=20).grid(row=2, column=1, sticky=tk.W, padx=5) + + # Training Parameters Frame + params_frame = ttk.LabelFrame(self.main_frame, text="Training Parameters", padding=10) + params_frame.pack(fill=tk.X, pady=(0, 10)) + + # Create two columns for parameters + left_params = ttk.Frame(params_frame) + left_params.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + right_params = ttk.Frame(params_frame) + right_params.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True) + + # Left column parameters + ttk.Label(left_params, text="Training Method:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.method_var = tk.StringVar(value="LoRA") + ttk.Combobox(left_params, textvariable=self.method_var, + values=["LoRA", "QLoRA", "Full Fine-tune", "PEFT"], + state="readonly", width=15).grid(row=0, column=1, padx=5) + + ttk.Label(left_params, text="Epochs:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.epochs_var = tk.StringVar(value="3") + ttk.Spinbox(left_params, from_=1, to=100, textvariable=self.epochs_var, width=15).grid(row=1, column=1, padx=5) + + ttk.Label(left_params, text="Batch Size:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.batch_var = tk.StringVar(value="4") + ttk.Spinbox(left_params, from_=1, to=64, textvariable=self.batch_var, width=15).grid(row=2, column=1, padx=5) + + ttk.Label(left_params, text="Learning Rate:").grid(row=3, column=0, sticky=tk.W, pady=5) + self.lr_var = tk.StringVar(value="2e-4") + ttk.Entry(left_params, textvariable=self.lr_var, width=15).grid(row=3, column=1, padx=5) + + # Right column parameters + ttk.Label(right_params, text="LoRA Rank:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.lora_rank_var = tk.StringVar(value="8") + ttk.Spinbox(right_params, from_=4, to=128, textvariable=self.lora_rank_var, width=15).grid(row=0, column=1, padx=5) + + ttk.Label(right_params, text="LoRA Alpha:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.lora_alpha_var = tk.StringVar(value="16") + ttk.Spinbox(right_params, from_=8, to=256, textvariable=self.lora_alpha_var, width=15).grid(row=1, column=1, padx=5) + + ttk.Label(right_params, text="Max Length:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.max_length_var = tk.StringVar(value="512") + ttk.Spinbox(right_params, from_=128, to=4096, increment=128, textvariable=self.max_length_var, width=15).grid(row=2, column=1, padx=5) + + ttk.Label(right_params, text="Warmup Steps:").grid(row=3, column=0, sticky=tk.W, pady=5) + self.warmup_var = tk.StringVar(value="100") + ttk.Spinbox(right_params, from_=0, to=1000, textvariable=self.warmup_var, width=15).grid(row=3, column=1, padx=5) + + # Hardware Settings Frame + hardware_frame = ttk.LabelFrame(self.main_frame, text="Hardware Settings", padding=10) + hardware_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Label(hardware_frame, text="Device:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.device_var = tk.StringVar(value="cuda") + ttk.Combobox(hardware_frame, textvariable=self.device_var, + values=["cuda", "cpu", "mps"], + state="readonly", width=15).grid(row=0, column=1, padx=5) + + ttk.Label(hardware_frame, text="Mixed Precision:").grid(row=0, column=2, sticky=tk.W, padx=(20, 5)) + self.mixed_precision_var = tk.BooleanVar(value=True) + ttk.Checkbutton(hardware_frame, variable=self.mixed_precision_var).grid(row=0, column=3) + + ttk.Label(hardware_frame, text="Gradient Checkpointing:").grid(row=0, column=4, sticky=tk.W, padx=(20, 5)) + self.grad_checkpoint_var = tk.BooleanVar(value=True) + ttk.Checkbutton(hardware_frame, variable=self.grad_checkpoint_var).grid(row=0, column=5) + + # Control Buttons Frame + control_frame = ttk.Frame(self.main_frame) + control_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Button(control_frame, text="Start Training", command=self._start_training, + style="Accent.TButton").pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Stop Training", command=self._stop_training).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Save Config", command=self._save_config).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Load Config", command=self._load_config).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Validate Dataset", command=self._validate_dataset).pack(side=tk.LEFT, padx=5) + + # Progress Frame + progress_frame = ttk.LabelFrame(self.main_frame, text="Training Progress", padding=10) + progress_frame.pack(fill=tk.BOTH, expand=True) + + # Progress bar + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 10)) + + # Status label + self.status_label = ttk.Label(progress_frame, text="Ready to start training", foreground="green") + self.status_label.pack(anchor=tk.W, pady=(0, 10)) + + # Training log + ttk.Label(progress_frame, text="Training Log:", font=("Arial", 10, "bold")).pack(anchor=tk.W) + self.log_text = scrolledtext.ScrolledText(progress_frame, height=10, width=80, state=tk.DISABLED) + self.log_text.pack(fill=tk.BOTH, expand=True) + + def _browse_base_model(self): + """Browse for base model file.""" + filename = filedialog.askopenfilename( + title="Select Base Model", + filetypes=[("GGUF files", "*.gguf"), ("All files", "*.*")] + ) + if filename: + self.base_model_var.set(filename) + self._log(f"Selected base model: {filename}") + + def _browse_dataset(self): + """Browse for dataset file.""" + filename = filedialog.askopenfilename( + title="Select Dataset", + filetypes=[ + ("JSON files", "*.json"), + ("JSONL files", "*.jsonl"), + ("CSV files", "*.csv"), + ("Text files", "*.txt"), + ("All files", "*.*") + ] + ) + if filename: + self.dataset_var.set(filename) + self._log(f"Selected dataset: {filename}") + + def _start_training(self): + """Start the fine-tuning process.""" + if self.training_in_progress: + messagebox.showwarning("Training in Progress", "Training is already in progress!") + return + + # Validate inputs + if not self.base_model_var.get(): + messagebox.showerror("Error", "Please select a base model") + return + + if not self.dataset_var.get(): + messagebox.showerror("Error", "Please select a dataset") + return + + self.training_in_progress = True + self.status_label.config(text="Training in progress...", foreground="orange") + self._log("=" * 50) + self._log("Starting fine-tuning process...") + self._log(f"Base Model: {self.base_model_var.get()}") + self._log(f"Dataset: {self.dataset_var.get()}") + self._log(f"Method: {self.method_var.get()}") + self._log(f"Epochs: {self.epochs_var.get()}") + self._log("=" * 50) + + # TODO: Implement actual training logic + messagebox.showinfo("Coming Soon", "Fine-tuning functionality will be implemented soon!") + + self.training_in_progress = False + self.status_label.config(text="Training complete (stub)", foreground="green") + + def _stop_training(self): + """Stop the training process.""" + if not self.training_in_progress: + messagebox.showinfo("No Training", "No training in progress") + return + + self.training_in_progress = False + self.status_label.config(text="Training stopped", foreground="red") + self._log("Training stopped by user") + + def _save_config(self): + """Save training configuration to file.""" + config = { + "base_model": self.base_model_var.get(), + "output_model": self.output_model_var.get(), + "dataset": self.dataset_var.get(), + "format": self.format_var.get(), + "split": self.split_var.get(), + "method": self.method_var.get(), + "epochs": self.epochs_var.get(), + "batch_size": self.batch_var.get(), + "learning_rate": self.lr_var.get(), + "lora_rank": self.lora_rank_var.get(), + "lora_alpha": self.lora_alpha_var.get(), + "max_length": self.max_length_var.get(), + "warmup_steps": self.warmup_var.get(), + "device": self.device_var.get(), + "mixed_precision": self.mixed_precision_var.get(), + "gradient_checkpointing": self.grad_checkpoint_var.get() + } + + filename = filedialog.asksaveasfilename( + title="Save Training Config", + defaultextension=".json", + filetypes=[("JSON files", "*.json"), ("All files", "*.*")] + ) + + if filename: + with open(filename, 'w') as f: + json.dump(config, f, indent=2) + self._log(f"Config saved to {filename}") + messagebox.showinfo("Saved", f"Configuration saved to {filename}") + + def _load_config(self): + """Load training configuration from file.""" + filename = filedialog.askopenfilename( + title="Load Training Config", + filetypes=[("JSON files", "*.json"), ("All files", "*.*")] + ) + + if filename: + try: + with open(filename, 'r') as f: + config = json.load(f) + + # Load values from config + self.base_model_var.set(config.get("base_model", "")) + self.output_model_var.set(config.get("output_model", "my-finetuned-model")) + self.dataset_var.set(config.get("dataset", "")) + self.format_var.set(config.get("format", "alpaca")) + self.split_var.set(config.get("split", "90/10")) + self.method_var.set(config.get("method", "LoRA")) + self.epochs_var.set(config.get("epochs", "3")) + self.batch_var.set(config.get("batch_size", "4")) + self.lr_var.set(config.get("learning_rate", "2e-4")) + self.lora_rank_var.set(config.get("lora_rank", "8")) + self.lora_alpha_var.set(config.get("lora_alpha", "16")) + self.max_length_var.set(config.get("max_length", "512")) + self.warmup_var.set(config.get("warmup_steps", "100")) + self.device_var.set(config.get("device", "cuda")) + self.mixed_precision_var.set(config.get("mixed_precision", True)) + self.grad_checkpoint_var.set(config.get("gradient_checkpointing", True)) + + self._log(f"Config loaded from {filename}") + messagebox.showinfo("Loaded", f"Configuration loaded from {filename}") + + except Exception as e: + messagebox.showerror("Error", f"Failed to load config: {e}") + + def _validate_dataset(self): + """Validate the selected dataset.""" + if not self.dataset_var.get(): + messagebox.showerror("Error", "Please select a dataset first") + return + + dataset_path = self.dataset_var.get() + if not os.path.exists(dataset_path): + messagebox.showerror("Error", f"Dataset file not found: {dataset_path}") + return + + # TODO: Implement actual dataset validation + self._log(f"Validating dataset: {dataset_path}") + self._log("Dataset validation (stub) - would check format, size, etc.") + messagebox.showinfo("Validation", "Dataset validation complete (stub)") + + def _log(self, message: str): + """Add message to training log.""" + self.log_text.config(state=tk.NORMAL) + timestamp = datetime.now().strftime("%H:%M:%S") + self.log_text.insert(tk.END, f"[{timestamp}] {message}\n") + self.log_text.see(tk.END) + self.log_text.config(state=tk.DISABLED) + + def set_model(self, model_path: Optional[str]): + """Set the current model for fine-tuning.""" + self.current_model = model_path + if model_path: + self.base_model_var.set(model_path) + self._log(f"Model selected: {Path(model_path).name}") \ No newline at end of file diff --git a/grouped_download_gui.py b/grouped_download_gui.py new file mode 100644 index 0000000..dcdd749 --- /dev/null +++ b/grouped_download_gui.py @@ -0,0 +1,652 @@ +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Dict, Any, Optional +import os +import subprocess +import platform +from grouped_download_manager import ( + GroupedDownloadManager, DownloadGroup, DownloadItem, DownloadStatus, FileSelectionDialog +) + + +class CollapsibleDownloadWidget: + """A collapsible widget for displaying a download group and its files.""" + + def __init__(self, parent: ttk.Frame, group: DownloadGroup, manager: GroupedDownloadManager): + self.parent = parent + self.group = group + self.manager = manager + self.expanded = group.expanded + + # Main container frame + self.main_frame = ttk.Frame(parent) + self.main_frame.pack(fill=tk.X, padx=5, pady=2) + + # Group header frame (always visible) + self.header_frame = ttk.Frame(self.main_frame, relief=tk.RAISED, borderwidth=1) + self.header_frame.pack(fill=tk.X) + + # Group details frame (collapsible) + self.details_frame = ttk.Frame(self.main_frame) + + self.build_header() + self.build_details() + self.update_display() + + # Set initial expansion state + if self.expanded: + self.show_details() + else: + self.hide_details() + + def build_header(self): + """Build the group header with overall progress and controls.""" + # Left side - expand/collapse button and info + left_frame = ttk.Frame(self.header_frame) + left_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5, pady=5) + + # Expand/collapse button + self.expand_btn = ttk.Button( + left_frame, + text="▼" if self.expanded else "▶", + width=3, + command=self.toggle_expansion + ) + self.expand_btn.pack(side=tk.LEFT, padx=(0, 5)) + + # Group info + info_frame = ttk.Frame(left_frame) + info_frame.pack(side=tk.LEFT, fill=tk.X, expand=True) + + # Title and status + title_frame = ttk.Frame(info_frame) + title_frame.pack(fill=tk.X) + + self.title_label = ttk.Label( + title_frame, + text=self.group.name, + font=("TkDefaultFont", 10, "bold") + ) + self.title_label.pack(side=tk.LEFT) + + self.status_label = ttk.Label( + title_frame, + text=self.group.status.value, + foreground=self.get_status_color(self.group.status) + ) + self.status_label.pack(side=tk.LEFT, padx=(10, 0)) + + # Progress info + progress_frame = ttk.Frame(info_frame) + progress_frame.pack(fill=tk.X, pady=(2, 0)) + + self.progress_label = ttk.Label(progress_frame, text="") + self.progress_label.pack(side=tk.LEFT) + + self.speed_label = ttk.Label(progress_frame, text="") + self.speed_label.pack(side=tk.LEFT, padx=(10, 0)) + + self.eta_label = ttk.Label(progress_frame, text="") + self.eta_label.pack(side=tk.LEFT, padx=(10, 0)) + + # Progress bar + self.progress_bar = ttk.Progressbar( + info_frame, + length=300, + mode='determinate' + ) + self.progress_bar.pack(fill=tk.X, pady=(2, 0)) + + # Right side - control buttons + control_frame = ttk.Frame(self.header_frame) + control_frame.pack(side=tk.RIGHT, padx=5, pady=5) + + self.pause_btn = ttk.Button( + control_frame, + text="Pause", + width=8, + command=self.pause_group + ) + self.pause_btn.pack(side=tk.LEFT, padx=2) + + self.resume_btn = ttk.Button( + control_frame, + text="Resume", + width=8, + command=self.resume_group + ) + self.resume_btn.pack(side=tk.LEFT, padx=2) + + self.cancel_btn = ttk.Button( + control_frame, + text="Cancel", + width=8, + command=self.cancel_group + ) + self.cancel_btn.pack(side=tk.LEFT, padx=2) + + self.remove_btn = ttk.Button( + control_frame, + text="Remove", + width=8, + command=self.remove_group + ) + self.remove_btn.pack(side=tk.LEFT, padx=2) + + def build_details(self): + """Build the collapsible details section with individual files.""" + # File list frame + files_frame = ttk.LabelFrame(self.details_frame, text="Files", padding=5) + files_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # Create treeview for files + columns = ("status", "progress", "size", "speed") + self.file_tree = ttk.Treeview( + files_frame, + columns=columns, + show="tree headings", + height=8 + ) + self.file_tree.pack(fill=tk.BOTH, expand=True) + + # Configure columns + self.file_tree.heading("#0", text="Filename") + self.file_tree.heading("status", text="Status") + self.file_tree.heading("progress", text="Progress") + self.file_tree.heading("size", text="Size") + self.file_tree.heading("speed", text="Speed") + + self.file_tree.column("#0", width=200) + self.file_tree.column("status", width=100) + self.file_tree.column("progress", width=100) + self.file_tree.column("size", width=100) + self.file_tree.column("speed", width=100) + + # File controls + file_controls = ttk.Frame(files_frame) + file_controls.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button( + file_controls, + text="Select All", + command=self.select_all_files + ).pack(side=tk.LEFT, padx=2) + + ttk.Button( + file_controls, + text="Deselect All", + command=self.deselect_all_files + ).pack(side=tk.LEFT, padx=2) + + ttk.Button( + file_controls, + text="Open Folder", + command=self.open_folder + ).pack(side=tk.RIGHT, padx=2) + + # Bind file selection + self.file_tree.bind("", self.on_file_click) + self.file_tree.bind("", self.on_file_double_click) + + # Populate files + self.populate_files() + + def populate_files(self): + """Populate the file tree with download items.""" + # Clear existing items + for item in self.file_tree.get_children(): + self.file_tree.delete(item) + + # Add files + for download_id, item in self.group.files.items(): + checkbox = "☑" if item.selected else "☐" + filename = f"{checkbox} {item.filename}" + + size_text = self.format_size(item.total_size) if item.total_size > 0 else "-" + if item.downloaded_size > 0 and item.total_size > 0: + size_text = f"{self.format_size(item.downloaded_size)} / {size_text}" + + speed_text = f"{self.format_size(item.speed)}/s" if item.speed > 0 else "-" + + tree_id = self.file_tree.insert( + "", "end", + text=filename, + values=( + item.status.value, + f"{item.progress:.1f}%", + size_text, + speed_text + ), + tags=(download_id,) + ) + + # Color code by status + if item.status == DownloadStatus.COMPLETED: + self.file_tree.item(tree_id, tags=(download_id, "completed")) + elif item.status == DownloadStatus.FAILED: + self.file_tree.item(tree_id, tags=(download_id, "failed")) + elif item.status == DownloadStatus.DOWNLOADING: + self.file_tree.item(tree_id, tags=(download_id, "downloading")) + + # Configure tag colors + self.file_tree.tag_configure("completed", background="#d4edda") + self.file_tree.tag_configure("failed", background="#f8d7da") + self.file_tree.tag_configure("downloading", background="#d1ecf1") + + def toggle_expansion(self): + """Toggle the expansion state of the widget.""" + self.expanded = not self.expanded + self.group.expanded = self.expanded # Update group state + + if self.expanded: + self.show_details() + else: + self.hide_details() + + self.expand_btn.config(text="▼" if self.expanded else "▶") + + def show_details(self): + """Show the details frame.""" + self.details_frame.pack(fill=tk.BOTH, expand=True, pady=(2, 0)) + + def hide_details(self): + """Hide the details frame.""" + self.details_frame.pack_forget() + + def update_display(self): + """Update the display with current group status.""" + # Update status + self.status_label.config( + text=self.group.status.value, + foreground=self.get_status_color(self.group.status) + ) + + # Update progress + self.progress_bar['value'] = self.group.progress + + # Update progress text + if self.group.total_size > 0: + progress_text = ( + f"{self.group.progress:.1f}% - " + f"{self.format_size(self.group.downloaded_size)} / " + f"{self.format_size(self.group.total_size)}" + ) + else: + progress_text = f"{self.group.progress:.1f}%" + + self.progress_label.config(text=progress_text) + + # Update speed and ETA + if self.group.active_speed > 0: + self.speed_label.config(text=f"Speed: {self.format_size(self.group.active_speed)}/s") + else: + self.speed_label.config(text="") + + if self.group.eta > 0: + self.eta_label.config(text=f"ETA: {self.format_time(self.group.eta)}") + else: + self.eta_label.config(text="") + + # Update file list if expanded + if self.expanded: + self.populate_files() + + # Update button states + self.update_button_states() + + def update_button_states(self): + """Update button states based on group status.""" + status = self.group.status + + if status == DownloadStatus.DOWNLOADING: + self.pause_btn.config(state="normal") + self.resume_btn.config(state="disabled") + self.cancel_btn.config(state="normal") + elif status == DownloadStatus.PAUSED: + self.pause_btn.config(state="disabled") + self.resume_btn.config(state="normal") + self.cancel_btn.config(state="normal") + elif status in [DownloadStatus.FAILED, DownloadStatus.AUTH_REQUIRED]: + self.pause_btn.config(state="disabled") + self.resume_btn.config(state="normal") + self.cancel_btn.config(state="disabled") + elif status == DownloadStatus.COMPLETED: + self.pause_btn.config(state="disabled") + self.resume_btn.config(state="disabled") + self.cancel_btn.config(state="disabled") + else: # QUEUED, CANCELLED + self.pause_btn.config(state="disabled") + self.resume_btn.config(state="normal") + self.cancel_btn.config(state="normal") + + def get_status_color(self, status: DownloadStatus) -> str: + """Get color for status display.""" + color_map = { + DownloadStatus.QUEUED: "blue", + DownloadStatus.DOWNLOADING: "green", + DownloadStatus.PAUSED: "orange", + DownloadStatus.COMPLETED: "green", + DownloadStatus.FAILED: "red", + DownloadStatus.CANCELLED: "gray", + DownloadStatus.AUTH_REQUIRED: "red" + } + return color_map.get(status, "black") + + def format_size(self, bytes_size: float) -> str: + """Format bytes to human readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + + def format_time(self, seconds: int) -> str: + """Format seconds to human readable time.""" + if seconds < 60: + return f"{seconds}s" + elif seconds < 3600: + minutes = seconds // 60 + secs = seconds % 60 + return f"{minutes}m {secs}s" + else: + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + return f"{hours}h {minutes}m" + + def on_file_click(self, event): + """Handle file clicks for selection toggle.""" + item_id = self.file_tree.identify_row(event.y) + if item_id: + tags = self.file_tree.item(item_id, "tags") + if tags: + download_id = tags[0] + self.manager.toggle_file_selection(download_id) + + def on_file_double_click(self, event): + """Handle file double-click for individual file controls.""" + item_id = self.file_tree.identify_row(event.y) + if item_id: + tags = self.file_tree.item(item_id, "tags") + if tags: + download_id = tags[0] + # Show context menu or individual file controls + self.show_file_context_menu(event, download_id) + + def show_file_context_menu(self, event, download_id: str): + """Show context menu for individual file.""" + # Find the download item + item = self.group.files.get(download_id) + if not item: + return + + context_menu = tk.Menu(self.file_tree, tearoff=0) + + if item.selected: + context_menu.add_command(label="Deselect", command=lambda: self.manager.toggle_file_selection(download_id)) + else: + context_menu.add_command(label="Select", command=lambda: self.manager.toggle_file_selection(download_id)) + + context_menu.add_separator() + + if item.status == DownloadStatus.DOWNLOADING: + context_menu.add_command(label="Pause", command=lambda: self.manager.pause_download(download_id)) + elif item.is_resumable: + context_menu.add_command(label="Resume", command=lambda: self.manager.resume_download(download_id)) + + if item.status not in [DownloadStatus.COMPLETED]: + context_menu.add_command(label="Cancel", command=lambda: self.manager.cancel_download(download_id)) + + context_menu.add_separator() + context_menu.add_command(label="Open Folder", command=lambda: self.open_file_folder(item)) + + try: + context_menu.post(event.x_root, event.y_root) + finally: + context_menu.grab_release() + + def select_all_files(self): + """Select all files in the group.""" + for download_id, item in self.group.files.items(): + if not item.selected: + self.manager.toggle_file_selection(download_id) + + def deselect_all_files(self): + """Deselect all files in the group.""" + for download_id, item in self.group.files.items(): + if item.selected: + self.manager.toggle_file_selection(download_id) + + def pause_group(self): + """Pause all downloads in the group.""" + self.manager.pause_group(self.group.id) + + def resume_group(self): + """Resume all downloads in the group.""" + self.manager.resume_group(self.group.id) + + def cancel_group(self): + """Cancel all downloads in the group.""" + if messagebox.askyesno("Cancel Downloads", + f"Cancel all downloads in '{self.group.name}'?"): + self.manager.cancel_group(self.group.id) + + def remove_group(self): + """Remove the entire group.""" + if messagebox.askyesno("Remove Group", + f"Remove download group '{self.group.name}'?\\nThis will cancel any active downloads."): + self.manager.remove_group(self.group.id) + + def open_folder(self): + """Open the download folder.""" + if self.group.files: + # Get folder from first file + first_file = next(iter(self.group.files.values())) + folder = os.path.dirname(first_file.save_path) + self.open_directory(folder) + + def open_file_folder(self, item: DownloadItem): + """Open folder for specific file.""" + folder = os.path.dirname(item.save_path) + self.open_directory(folder) + + def open_directory(self, path: str): + """Open directory in file explorer.""" + if os.path.exists(path): + try: + if platform.system() == "Windows": + subprocess.run(["explorer", path]) + elif platform.system() == "Darwin": + subprocess.run(["open", path]) + else: + subprocess.run(["xdg-open", path]) + except Exception as e: + messagebox.showerror("Error", f"Could not open folder: {e}") + + def destroy(self): + """Clean up the widget.""" + self.main_frame.destroy() + + +class GroupedDownloadManagerTab: + """Main tab for the grouped download manager.""" + + def __init__(self, parent: ttk.Frame, download_manager: GroupedDownloadManager): + self.parent = parent + self.manager = download_manager + self.group_widgets: Dict[str, CollapsibleDownloadWidget] = {} + + # Register callbacks + self.manager.register_callback('on_progress', self._on_progress) + self.manager.register_callback('on_status_change', self._on_status_change) + self.manager.register_callback('on_complete', self._on_complete) + self.manager.register_callback('on_error', self._on_error) + self.manager.register_callback('on_remove', self._on_remove) + self.manager.register_callback('on_group_change', self._on_group_change) + + self._build_ui() + self._update_display() + + def _build_ui(self): + """Build the main UI.""" + # Top controls + controls_frame = ttk.Frame(self.parent) + controls_frame.pack(fill=tk.X, padx=10, pady=5) + + # Left controls + left_controls = ttk.Frame(controls_frame) + left_controls.pack(side=tk.LEFT) + + ttk.Button(left_controls, text="Clear Completed", + command=self._clear_completed).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Pause All", + command=self._pause_all).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Resume All", + command=self._resume_all).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Expand All", + command=self._expand_all).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Collapse All", + command=self._collapse_all).pack(side=tk.LEFT, padx=2) + + # Right status + right_controls = ttk.Frame(controls_frame) + right_controls.pack(side=tk.RIGHT) + + self.status_label = ttk.Label(right_controls, text="Groups: 0, Downloads: 0") + self.status_label.pack(side=tk.RIGHT) + + # Scrollable frame for download groups + canvas_frame = ttk.Frame(self.parent) + canvas_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) + + # Create canvas and scrollbar + self.canvas = tk.Canvas(canvas_frame, bg='white') + self.scrollbar = ttk.Scrollbar(canvas_frame, orient="vertical", command=self.canvas.yview) + self.scrollable_frame = ttk.Frame(self.canvas) + + self.scrollable_frame.bind( + "", + lambda e: self.canvas.configure(scrollregion=self.canvas.bbox("all")) + ) + + self.canvas.create_window((0, 0), window=self.scrollable_frame, anchor="nw") + self.canvas.configure(yscrollcommand=self.scrollbar.set) + + self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # Bind mouse wheel + self.canvas.bind("", self._on_mouse_wheel) + + def _on_mouse_wheel(self, event): + """Handle mouse wheel scrolling.""" + self.canvas.yview_scroll(int(-1 * (event.delta / 120)), "units") + + def add_download_group(self, group: DownloadGroup): + """Add a new download group widget.""" + if group.id not in self.group_widgets: + widget = CollapsibleDownloadWidget(self.scrollable_frame, group, self.manager) + self.group_widgets[group.id] = widget + + def _on_progress(self, item): + """Handle progress updates.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_group_display(item.group_id)) + + def _on_status_change(self, item): + """Handle status changes.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_group_display(item.group_id)) + + def _on_complete(self, item): + """Handle completion.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_group_display(item.group_id)) + + def _on_error(self, item): + """Handle errors.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_group_display(item.group_id)) + + def _on_remove(self, group): + """Handle group removal.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._remove_group_widget(group.id)) + + def _on_group_change(self, group): + """Handle group changes.""" + if hasattr(self.parent, 'after'): + self.parent.after(0, lambda: self._update_group_display(group.id)) + + def _update_group_display(self, group_id: str): + """Update display for a specific group.""" + if group_id in self.group_widgets and group_id in self.manager.groups: + widget = self.group_widgets[group_id] + widget.update_display() + + def _remove_group_widget(self, group_id: str): + """Remove a group widget.""" + if group_id in self.group_widgets: + widget = self.group_widgets[group_id] + widget.destroy() + del self.group_widgets[group_id] + + def _clear_completed(self): + """Clear completed groups.""" + self.manager.clear_completed_groups() + + def _pause_all(self): + """Pause all active downloads.""" + for group in self.manager.get_all_groups(): + self.manager.pause_group(group.id) + + def _resume_all(self): + """Resume all paused downloads.""" + for group in self.manager.get_all_groups(): + self.manager.resume_group(group.id) + + def _expand_all(self): + """Expand all groups.""" + for widget in self.group_widgets.values(): + if not widget.expanded: + widget.toggle_expansion() + + def _collapse_all(self): + """Collapse all groups.""" + for widget in self.group_widgets.values(): + if widget.expanded: + widget.toggle_expansion() + + def _update_display(self): + """Update the display periodically.""" + try: + # Update status summary + all_groups = self.manager.get_all_groups() + total_files = sum(len(group.files) for group in all_groups) + active_downloads = sum( + 1 for group in all_groups + for item in group.files.values() + if item.status == DownloadStatus.DOWNLOADING + ) + + self.status_label.config(text=f"Groups: {len(all_groups)}, Downloads: {active_downloads} active / {total_files} total") + + # Add any new groups + for group in all_groups: + if group.id not in self.group_widgets: + self.add_download_group(group) + + # Remove widgets for deleted groups + to_remove = [] + for group_id in self.group_widgets: + if group_id not in self.manager.groups: + to_remove.append(group_id) + + for group_id in to_remove: + self._remove_group_widget(group_id) + + except Exception as e: + print(f"Error updating grouped download display: {e}") + finally: + # Schedule next update + if hasattr(self.parent, 'after'): + self.parent.after(200, self._update_display) # Update every 200ms \ No newline at end of file diff --git a/grouped_download_manager.py b/grouped_download_manager.py new file mode 100644 index 0000000..c6ef5bf --- /dev/null +++ b/grouped_download_manager.py @@ -0,0 +1,774 @@ +import os +import requests +import threading +import time +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Optional, Dict, Any, List, Callable, Set +from dataclasses import dataclass, field +from enum import Enum +import json +from datetime import datetime +import queue +import uuid + + +class DownloadStatus(Enum): + """Download status enumeration.""" + QUEUED = "Queued" + DOWNLOADING = "Downloading" + PAUSED = "Paused" + COMPLETED = "Completed" + FAILED = "Failed" + CANCELLED = "Cancelled" + AUTH_REQUIRED = "Auth Required" + + +@dataclass +class DownloadItem: + """Represents a single file download.""" + id: str + group_id: str + repo_id: str + filename: str + url: str + save_path: str + total_size: int = 0 + downloaded_size: int = 0 + status: DownloadStatus = DownloadStatus.QUEUED + error_message: str = "" + start_time: Optional[float] = None + end_time: Optional[float] = None + speed: float = 0.0 + eta: int = 0 + headers: Dict[str, str] = field(default_factory=dict) + resume_position: int = 0 + selected: bool = True # Whether this file should be downloaded + + @property + def progress(self) -> float: + """Calculate download progress percentage.""" + if self.total_size > 0: + return (self.downloaded_size / self.total_size) * 100 + return 0.0 + + @property + def is_resumable(self) -> bool: + """Check if download can be resumed.""" + return self.status in [DownloadStatus.PAUSED, DownloadStatus.FAILED] + + +@dataclass +class DownloadGroup: + """Represents a group of related downloads (e.g., all files from one model).""" + id: str + repo_id: str + name: str + description: str = "" + created_time: float = field(default_factory=time.time) + files: Dict[str, DownloadItem] = field(default_factory=dict) + expanded: bool = False + + @property + def total_size(self) -> int: + """Total size of all selected files in group.""" + return sum(item.total_size for item in self.files.values() if item.selected) + + @property + def downloaded_size(self) -> int: + """Total downloaded size of all selected files.""" + return sum(item.downloaded_size for item in self.files.values() if item.selected) + + @property + def progress(self) -> float: + """Overall progress percentage for the group.""" + if self.total_size > 0: + return (self.downloaded_size / self.total_size) * 100 + return 0.0 + + @property + def status(self) -> DownloadStatus: + """Overall status of the group.""" + selected_files = [item for item in self.files.values() if item.selected] + if not selected_files: + return DownloadStatus.COMPLETED + + statuses = [item.status for item in selected_files] + + # If any are downloading, group is downloading + if DownloadStatus.DOWNLOADING in statuses: + return DownloadStatus.DOWNLOADING + + # If all completed, group is completed + if all(s == DownloadStatus.COMPLETED for s in statuses): + return DownloadStatus.COMPLETED + + # If any failed, group shows failed + if DownloadStatus.FAILED in statuses or DownloadStatus.AUTH_REQUIRED in statuses: + return DownloadStatus.FAILED + + # If any cancelled, show cancelled + if DownloadStatus.CANCELLED in statuses: + return DownloadStatus.CANCELLED + + # If any paused, show paused + if DownloadStatus.PAUSED in statuses: + return DownloadStatus.PAUSED + + # Otherwise queued + return DownloadStatus.QUEUED + + @property + def active_speed(self) -> float: + """Combined download speed of active files.""" + return sum(item.speed for item in self.files.values() + if item.status == DownloadStatus.DOWNLOADING and item.selected) + + @property + def eta(self) -> int: + """Estimated time to completion for the group.""" + if self.active_speed > 0: + remaining = self.total_size - self.downloaded_size + return int(remaining / self.active_speed) + return 0 + + +class GroupedDownloadManager: + """Manages grouped downloads with pause/resume support.""" + + def __init__(self, max_concurrent: int = 3): + self.groups: Dict[str, DownloadGroup] = {} + self.download_queue: queue.Queue = queue.Queue() + self.active_downloads: Dict[str, threading.Thread] = {} + self.max_concurrent = max_concurrent + self.callbacks: Dict[str, List[Callable]] = { + 'on_progress': [], + 'on_status_change': [], + 'on_complete': [], + 'on_error': [], + 'on_remove': [], + 'on_group_change': [] + } + self._stop_flags: Dict[str, threading.Event] = {} + self._pause_flags: Dict[str, threading.Event] = {} + self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) + self._worker_thread.start() + + def create_download_group(self, repo_id: str, name: str, description: str = "") -> str: + """Create a new download group.""" + group_id = str(uuid.uuid4()) + group = DownloadGroup( + id=group_id, + repo_id=repo_id, + name=name, + description=description + ) + self.groups[group_id] = group + self._trigger_callback('on_group_change', group) + return group_id + + def add_file_to_group(self, group_id: str, filename: str, url: str, save_path: str, + headers: Optional[Dict[str, str]] = None, selected: bool = True) -> str: + """Add a file to an existing download group.""" + if group_id not in self.groups: + raise ValueError(f"Group {group_id} not found") + + group = self.groups[group_id] + download_id = f"{group_id}_{filename}_{int(time.time())}" + + item = DownloadItem( + id=download_id, + group_id=group_id, + repo_id=group.repo_id, + filename=filename, + url=url, + save_path=save_path, + headers=headers or {}, + selected=selected + ) + + group.files[download_id] = item + + # Queue for download if selected + if selected: + self.download_queue.put(download_id) + + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + + return download_id + + def toggle_file_selection(self, download_id: str): + """Toggle whether a file should be downloaded.""" + for group in self.groups.values(): + if download_id in group.files: + item = group.files[download_id] + item.selected = not item.selected + + if item.selected and item.status == DownloadStatus.QUEUED: + # Add to queue if now selected + self.download_queue.put(download_id) + elif not item.selected and item.status in [DownloadStatus.QUEUED, DownloadStatus.PAUSED]: + # Cancel if unselected and not yet completed + self.cancel_download(download_id) + + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + break + + def pause_download(self, download_id: str): + """Pause a download.""" + if download_id in self._pause_flags: + self._pause_flags[download_id].set() + + for group in self.groups.values(): + if download_id in group.files: + item = group.files[download_id] + item.status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + break + + def pause_group(self, group_id: str): + """Pause all downloads in a group.""" + if group_id in self.groups: + group = self.groups[group_id] + for download_id in group.files: + if group.files[download_id].selected: + self.pause_download(download_id) + + def resume_download(self, download_id: str): + """Resume a paused download.""" + for group in self.groups.values(): + if download_id in group.files: + item = group.files[download_id] + if item.is_resumable and item.selected: + item.status = DownloadStatus.QUEUED + item.resume_position = item.downloaded_size + self.download_queue.put(download_id) + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + break + + def resume_group(self, group_id: str): + """Resume all paused downloads in a group.""" + if group_id in self.groups: + group = self.groups[group_id] + for download_id in group.files: + item = group.files[download_id] + if item.is_resumable and item.selected: + self.resume_download(download_id) + + def cancel_download(self, download_id: str): + """Cancel a download.""" + if download_id in self._stop_flags: + self._stop_flags[download_id].set() + + for group in self.groups.values(): + if download_id in group.files: + item = group.files[download_id] + item.status = DownloadStatus.CANCELLED + + # Remove partial file + if os.path.exists(item.save_path): + try: + os.remove(item.save_path) + except Exception: + pass + + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + break + + def cancel_group(self, group_id: str): + """Cancel all downloads in a group.""" + if group_id in self.groups: + group = self.groups[group_id] + for download_id in group.files: + self.cancel_download(download_id) + + def remove_group(self, group_id: str): + """Remove an entire download group.""" + if group_id in self.groups: + group = self.groups[group_id] + + # Cancel all active downloads first + for download_id in group.files: + if download_id in self._stop_flags: + self._stop_flags[download_id].set() + + # Remove the group + del self.groups[group_id] + self._trigger_callback('on_remove', group) + + def _process_queue(self): + """Process download queue.""" + while True: + # Check if we can start more downloads + if len(self.active_downloads) < self.max_concurrent: + try: + download_id = self.download_queue.get(timeout=1) + + # Find the download item + item = None + for group in self.groups.values(): + if download_id in group.files: + item = group.files[download_id] + break + + if item and item.selected: + thread = threading.Thread( + target=self._download_file, + args=(download_id,), + daemon=True + ) + self.active_downloads[download_id] = thread + thread.start() + + except queue.Empty: + pass + + # Clean up finished downloads + finished = [] + for download_id, thread in self.active_downloads.items(): + if not thread.is_alive(): + finished.append(download_id) + + for download_id in finished: + del self.active_downloads[download_id] + + time.sleep(0.5) + + def _download_file(self, download_id: str): + """Download a file with resume support.""" + # Find the item + item = None + group = None + for g in self.groups.values(): + if download_id in g.files: + item = g.files[download_id] + group = g + break + + if not item or not item.selected: + return + + # Create flags for this download + self._stop_flags[download_id] = threading.Event() + self._pause_flags[download_id] = threading.Event() + + try: + # Update status + item.status = DownloadStatus.DOWNLOADING + item.start_time = time.time() + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + + # Create directory if needed + os.makedirs(os.path.dirname(item.save_path), exist_ok=True) + + # Setup headers for resume + headers = item.headers.copy() + if item.resume_position > 0: + headers['Range'] = f'bytes={item.resume_position}-' + + # Make request + response = requests.get(item.url, headers=headers, stream=True, timeout=30) + + # Check for authentication issues + if response.status_code == 401 or response.status_code == 403: + item.status = DownloadStatus.AUTH_REQUIRED + item.error_message = f"Authentication failed: {response.status_code}" + self._trigger_callback('on_error', item) + self._trigger_callback('on_group_change', group) + return + + response.raise_for_status() + + # Get total size + if item.resume_position == 0: + item.total_size = int(response.headers.get('content-length', 0)) + + # Open file for writing (append if resuming) + mode = 'ab' if item.resume_position > 0 else 'wb' + + # Optimize chunk size based on file size + base_chunk_size = 1024 * 1024 # 1MB base chunk + if item.total_size > 100 * 1024 * 1024: # Files > 100MB + chunk_size = base_chunk_size * 4 # 4MB chunks + elif item.total_size > 10 * 1024 * 1024: # Files > 10MB + chunk_size = base_chunk_size * 2 # 2MB chunks + else: + chunk_size = base_chunk_size # 1MB chunks + + # Use buffered writing for better performance + buffer_size = chunk_size * 8 + + with open(item.save_path, mode, buffering=buffer_size) as f: + last_update = time.time() + bytes_since_update = 0 + update_interval = 0.5 # Update UI every 0.5 seconds + + for chunk in response.iter_content(chunk_size=chunk_size): + # Check stop flag + if self._stop_flags[download_id].is_set(): + return + + # Check pause flag + if self._pause_flags[download_id].is_set(): + item.status = DownloadStatus.PAUSED + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + return + + if chunk: + f.write(chunk) + item.downloaded_size += len(chunk) + bytes_since_update += len(chunk) + + # Calculate speed and ETA + current_time = time.time() + time_diff = current_time - last_update + + if time_diff >= update_interval: + item.speed = bytes_since_update / time_diff + + if item.speed > 0 and item.total_size > item.downloaded_size: + remaining = item.total_size - item.downloaded_size + item.eta = int(remaining / item.speed) + + self._trigger_callback('on_progress', item) + self._trigger_callback('on_group_change', group) + last_update = current_time + bytes_since_update = 0 + + # Force flush for USB drives + f.flush() + os.fsync(f.fileno()) + + # Download completed + item.status = DownloadStatus.COMPLETED + item.end_time = time.time() + self._trigger_callback('on_complete', item) + self._trigger_callback('on_group_change', group) + + except requests.exceptions.RequestException as e: + item.status = DownloadStatus.FAILED + item.error_message = str(e) + self._trigger_callback('on_error', item) + self._trigger_callback('on_group_change', group) + + except Exception as e: + item.status = DownloadStatus.FAILED + item.error_message = f"Unexpected error: {e}" + self._trigger_callback('on_error', item) + self._trigger_callback('on_group_change', group) + + finally: + # Clean up flags + if download_id in self._stop_flags: + del self._stop_flags[download_id] + if download_id in self._pause_flags: + del self._pause_flags[download_id] + + self._trigger_callback('on_status_change', item) + self._trigger_callback('on_group_change', group) + + def register_callback(self, event: str, callback: Callable): + """Register a callback for download events.""" + if event in self.callbacks: + self.callbacks[event].append(callback) + + def _trigger_callback(self, event: str, item): + """Trigger callbacks for an event.""" + for callback in self.callbacks.get(event, []): + try: + callback(item) + except Exception as e: + print(f"Callback error: {e}") + + def get_all_groups(self) -> List[DownloadGroup]: + """Get all download groups.""" + return list(self.groups.values()) + + def clear_completed_groups(self): + """Clear completed download groups.""" + to_remove = [] + for group_id, group in self.groups.items(): + if group.status == DownloadStatus.COMPLETED: + to_remove.append(group_id) + + for group_id in to_remove: + self.remove_group(group_id) + + +class FileSelectionDialog: + """Dialog for selecting which files to download from a model.""" + + def __init__(self, parent: tk.Tk, repo_id: str, files: List[Dict[str, Any]], + title: str = "Select Files to Download"): + self.parent = parent + self.repo_id = repo_id + self.files = files + self.selected_files = [] + self.result = None + + # Create dialog + self.dialog = tk.Toplevel(parent) + self.dialog.title(title) + self.dialog.geometry("800x600") + self.dialog.resizable(True, True) + + # Make modal + self.dialog.transient(parent) + self.dialog.grab_set() + + self.build_ui() + self.center_dialog() + + def build_ui(self): + """Build the file selection dialog UI.""" + main_frame = ttk.Frame(self.dialog, padding=10) + main_frame.pack(fill=tk.BOTH, expand=True) + + # Title and info + title_frame = ttk.Frame(main_frame) + title_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Label(title_frame, text=f"Select files to download from:", + font=("TkDefaultFont", 10, "bold")).pack(anchor=tk.W) + ttk.Label(title_frame, text=self.repo_id, + font=("TkDefaultFont", 12, "bold")).pack(anchor=tk.W) + + # Selection controls + controls_frame = ttk.Frame(main_frame) + controls_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Button(controls_frame, text="Select All", + command=self.select_all).pack(side=tk.LEFT, padx=(0, 5)) + ttk.Button(controls_frame, text="Deselect All", + command=self.deselect_all).pack(side=tk.LEFT, padx=5) + ttk.Button(controls_frame, text="Select GGUF Only", + command=self.select_gguf_only).pack(side=tk.LEFT, padx=5) + + # File list + list_frame = ttk.Frame(main_frame) + list_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + # Create treeview with file info + columns = ("size", "type", "modified") + self.file_tree = ttk.Treeview(list_frame, columns=columns, show="tree headings", height=15) + self.file_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # Configure columns + self.file_tree.heading("#0", text="Filename") + self.file_tree.heading("size", text="Size") + self.file_tree.heading("type", text="Type") + self.file_tree.heading("modified", text="Modified") + + self.file_tree.column("#0", width=300) + self.file_tree.column("size", width=100) + self.file_tree.column("type", width=80) + self.file_tree.column("modified", width=150) + + # Scrollbar + scrollbar = ttk.Scrollbar(list_frame, orient="vertical", command=self.file_tree.yview) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.file_tree.configure(yscrollcommand=scrollbar.set) + + # Populate file list + self.file_vars = {} + self.populate_file_list() + + # Bind click events + self.file_tree.bind("", self.on_tree_click) + + # Summary frame + summary_frame = ttk.LabelFrame(main_frame, text="Download Summary", padding=10) + summary_frame.pack(fill=tk.X, pady=(0, 10)) + + self.summary_label = ttk.Label(summary_frame, text="") + self.summary_label.pack(anchor=tk.W) + + # Buttons + button_frame = ttk.Frame(main_frame) + button_frame.pack(fill=tk.X) + + ttk.Button(button_frame, text="Download Selected", + command=self.download_selected).pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="Cancel", + command=self.cancel).pack(side=tk.RIGHT) + + # Update summary + self.update_summary() + + def populate_file_list(self): + """Populate the file list with checkboxes.""" + for i, file_info in enumerate(self.files): + filename = file_info.get('rfilename', file_info.get('path', 'Unknown')) + size = self.format_size(file_info.get('size', 0)) + file_type = self.get_file_type(filename) + modified = file_info.get('lastModified', 'Unknown') + + # Create selection variable + var = tk.BooleanVar(value=self.should_auto_select(filename)) + self.file_vars[filename] = var + + # Insert into tree + item_id = self.file_tree.insert( + "", "end", + text=f"☐ {filename}", + values=(size, file_type, modified), + tags=(filename,) + ) + + def should_auto_select(self, filename: str) -> bool: + """Determine if a file should be auto-selected.""" + # Auto-select GGUF files and small files like README + if filename.lower().endswith('.gguf'): + return True + if filename.lower() in ['readme.md', 'readme', 'license', 'license.txt']: + return True + if any(filename.lower().endswith(ext) for ext in ['.json', '.txt']) and 'config' in filename.lower(): + return True + return False + + def get_file_type(self, filename: str) -> str: + """Get file type from extension.""" + ext = os.path.splitext(filename)[1].lower() + type_map = { + '.gguf': 'GGUF', + '.safetensors': 'SafeTensors', + '.bin': 'Binary', + '.json': 'Config', + '.md': 'Markdown', + '.txt': 'Text', + '.py': 'Python', + '.yaml': 'YAML', + '.yml': 'YAML' + } + return type_map.get(ext, 'Other') + + def format_size(self, bytes_size: int) -> str: + """Format bytes to human readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + + def on_tree_click(self, event): + """Handle tree item clicks.""" + item = self.file_tree.identify_row(event.y) + if item: + tags = self.file_tree.item(item, "tags") + if tags: + filename = tags[0] + if filename in self.file_vars: + # Toggle selection + current = self.file_vars[filename].get() + self.file_vars[filename].set(not current) + self.update_file_display(item, filename) + self.update_summary() + + def update_file_display(self, item_id: str, filename: str): + """Update file display based on selection.""" + selected = self.file_vars[filename].get() + checkbox = "☑" if selected else "☐" + text = f"{checkbox} {filename}" + self.file_tree.item(item_id, text=text) + + def select_all(self): + """Select all files.""" + for var in self.file_vars.values(): + var.set(True) + self.refresh_display() + + def deselect_all(self): + """Deselect all files.""" + for var in self.file_vars.values(): + var.set(False) + self.refresh_display() + + def select_gguf_only(self): + """Select only GGUF files and essential files.""" + for filename, var in self.file_vars.items(): + should_select = ( + filename.lower().endswith('.gguf') or + filename.lower() in ['readme.md', 'readme', 'license'] or + 'config' in filename.lower() + ) + var.set(should_select) + self.refresh_display() + + def refresh_display(self): + """Refresh the display of all items.""" + for item in self.file_tree.get_children(): + tags = self.file_tree.item(item, "tags") + if tags: + filename = tags[0] + self.update_file_display(item, filename) + self.update_summary() + + def update_summary(self): + """Update the download summary.""" + selected_count = sum(1 for var in self.file_vars.values() if var.get()) + total_count = len(self.file_vars) + + selected_size = sum( + file_info.get('size', 0) + for i, file_info in enumerate(self.files) + if self.file_vars.get(file_info.get('rfilename', file_info.get('path', '')), tk.BooleanVar()).get() + ) + + size_text = self.format_size(selected_size) + self.summary_label.config( + text=f"Selected: {selected_count}/{total_count} files, Total size: {size_text}" + ) + + def download_selected(self): + """Start download of selected files.""" + self.selected_files = [ + (filename, file_info) + for filename, var in self.file_vars.items() + if var.get() + for file_info in self.files + if file_info.get('rfilename', file_info.get('path', '')) == filename + ] + + if not self.selected_files: + messagebox.showwarning("No Selection", "Please select at least one file to download.") + return + + self.result = 'download' + self.dialog.destroy() + + def cancel(self): + """Cancel the dialog.""" + self.result = 'cancel' + self.dialog.destroy() + + def center_dialog(self): + """Center dialog on parent.""" + self.dialog.update_idletasks() + + # Get parent position and size + parent_x = self.parent.winfo_x() + parent_y = self.parent.winfo_y() + parent_width = self.parent.winfo_width() + parent_height = self.parent.winfo_height() + + # Get dialog size + dialog_width = self.dialog.winfo_width() + dialog_height = self.dialog.winfo_height() + + # Calculate center position + x = parent_x + (parent_width - dialog_width) // 2 + y = parent_y + (parent_height - dialog_height) // 2 + + self.dialog.geometry(f"+{x}+{y}") + + def show(self): + """Show the dialog and return the result.""" + self.dialog.wait_window() + return self.result, self.selected_files \ No newline at end of file diff --git a/hf_downloader.py b/hf_downloader.py new file mode 100644 index 0000000..51ccc3d --- /dev/null +++ b/hf_downloader.py @@ -0,0 +1,586 @@ +import os +import requests +import tkinter as tk +from tkinter import ttk, messagebox, filedialog +from typing import Optional, List, Dict, Any +import threading +import json +from pathlib import Path +from dotenv import load_dotenv + +# Load the HuggingFace API key from HUGGINGFACE.env +load_dotenv("HUGGINGFACE.env") + + +class HuggingFaceAPI: + """Web API-based HuggingFace interface using direct HTTP requests.""" + + def __init__(self, api_key: Optional[str] = None, organization: Optional[str] = None): + # Ensure API key is properly cleaned of whitespace and newlines + raw_key = api_key or os.getenv("HF_API_KEY", "") + self.api_key = raw_key.strip().replace('\n', '').replace('\r', '') + if not self.api_key: + raise ValueError("HuggingFace API key not found. Please set HF_API_KEY in HUGGINGFACE.env") + + self.organization = organization.strip() if organization else None + self.base_url = "https://huggingface.co" + self.headers = {"Authorization": f"Bearer {self.api_key}"} + + # Add organization header if specified + if self.organization: + self.headers["X-Organization"] = self.organization + + def search_models(self, query: str = "", limit: int = 50, sort: str = "downloads") -> List[Dict[str, Any]]: + """Search for models using the web API.""" + url = f"{self.base_url}/api/models" + params = { + "limit": limit, + "sort": sort, + "direction": -1, + "full": True + } + if query: + params["search"] = query + + try: + response = requests.get(url, params=params, headers=self.headers, timeout=30) + response.raise_for_status() + return response.json() + except Exception as e: + print(f"Error searching models: {e}") + return [] + + def search_datasets(self, query: str = "", limit: int = 50, sort: str = "downloads") -> List[Dict[str, Any]]: + """Search for datasets using the web API.""" + url = f"{self.base_url}/api/datasets" + params = { + "limit": limit, + "sort": sort, + "direction": -1, + "full": True + } + if query: + params["search"] = query + + try: + response = requests.get(url, params=params, headers=self.headers, timeout=30) + response.raise_for_status() + return response.json() + except Exception as e: + print(f"Error searching datasets: {e}") + return [] + + def get_model_files(self, repo_id: str) -> List[Dict[str, Any]]: + """Get list of files in a model repository.""" + url = f"{self.base_url}/api/models/{repo_id}" + try: + response = requests.get(url, headers=self.headers, timeout=30) + response.raise_for_status() + data = response.json() + return data.get("siblings", []) + except Exception as e: + print(f"Error getting model files: {e}") + return [] + + def download_file(self, repo_id: str, filename: str, save_path: str, + progress_callback=None) -> bool: + """Download a file from HuggingFace.""" + url = f"{self.base_url}/{repo_id}/resolve/main/{filename}" + + try: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + response = requests.get(url, headers=self.headers, stream=True, timeout=30) + response.raise_for_status() + + total_size = int(response.headers.get('content-length', 0)) + downloaded = 0 + + with open(save_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + downloaded += len(chunk) + + if progress_callback and total_size > 0: + progress = (downloaded / total_size) * 100 + progress_callback(progress, downloaded, total_size) + + return True + + except Exception as e: + print(f"Error downloading file: {e}") + return False + + +class HuggingFaceDownloaderGUI: + """GUI for HuggingFace model and dataset search and download.""" + + def __init__(self, root: tk.Tk): + self.root = root + self.root.title("HuggingFace Downloader") + self.root.geometry("1200x700") + + # Initialize API + try: + self.api = HuggingFaceAPI() + except ValueError as e: + messagebox.showerror("API Key Error", str(e)) + self.api = None + + # Search variables + self.search_query = tk.StringVar() + self.search_type = tk.StringVar(value="Models") + self.filter_most_downloaded = tk.BooleanVar(value=True) + self.filter_most_liked = tk.BooleanVar(value=False) + self.filter_size = tk.BooleanVar(value=False) + + # Current results storage + self.current_results = [] + + self._build_ui() + + def _build_ui(self): + """Build the main UI.""" + # Search bar frame + search_frame = ttk.Frame(self.root) + search_frame.pack(fill=tk.X, padx=10, pady=10) + + # Search entry + self.search_entry = ttk.Entry(search_frame, textvariable=self.search_query, width=60) + self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + self.search_entry.bind("", lambda e: self._perform_search()) + + # Search type dropdown + self.type_dropdown = ttk.Combobox(search_frame, textvariable=self.search_type, + values=["Models", "Datasets"], + state="readonly", width=15) + self.type_dropdown.pack(side=tk.LEFT, padx=(10, 0)) + + # Search button + self.search_button = ttk.Button(search_frame, text="Search", command=self._perform_search) + self.search_button.pack(side=tk.LEFT, padx=(10, 0)) + + # Results frame with treeview + results_frame = ttk.Frame(self.root) + results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=(0, 10)) + + # Create treeview with columns + columns = ("creator", "name", "description", "keywords", "size", "metadata") + self.results_tree = ttk.Treeview(results_frame, columns=columns, show="headings", height=20) + + # Define column headings and widths + self.results_tree.heading("creator", text="Creator") + self.results_tree.heading("name", text="Name") + self.results_tree.heading("description", text="Description") + self.results_tree.heading("keywords", text="Keywords") + self.results_tree.heading("size", text="Size") + self.results_tree.heading("metadata", text="Metadata") + + self.results_tree.column("creator", width=150) + self.results_tree.column("name", width=200) + self.results_tree.column("description", width=300) + self.results_tree.column("keywords", width=150) + self.results_tree.column("size", width=100) + self.results_tree.column("metadata", width=200) + + # Scrollbars + vsb = ttk.Scrollbar(results_frame, orient="vertical", command=self.results_tree.yview) + hsb = ttk.Scrollbar(results_frame, orient="horizontal", command=self.results_tree.xview) + self.results_tree.configure(yscrollcommand=vsb.set, xscrollcommand=hsb.set) + + self.results_tree.grid(row=0, column=0, sticky="nsew") + vsb.grid(row=0, column=1, sticky="ns") + hsb.grid(row=1, column=0, sticky="ew") + + results_frame.grid_rowconfigure(0, weight=1) + results_frame.grid_columnconfigure(0, weight=1) + + # Bind double-click to download + self.results_tree.bind("", self._on_item_double_click) + + # Filter footer frame + filter_frame = ttk.Frame(self.root) + filter_frame.pack(fill=tk.X, padx=10, pady=(0, 10)) + + ttk.Label(filter_frame, text="Filter:").pack(side=tk.LEFT) + + ttk.Checkbutton(filter_frame, text="Most Downloaded", + variable=self.filter_most_downloaded, + command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0)) + + ttk.Checkbutton(filter_frame, text="Most Liked", + variable=self.filter_most_liked, + command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0)) + + ttk.Checkbutton(filter_frame, text="Size", + variable=self.filter_size, + command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0)) + + # Status bar + self.status_var = tk.StringVar(value="Ready") + status_bar = ttk.Label(self.root, textvariable=self.status_var, relief=tk.SUNKEN) + status_bar.pack(fill=tk.X, side=tk.BOTTOM) + + # Download button + download_frame = ttk.Frame(self.root) + download_frame.pack(fill=tk.X, padx=10, pady=(0, 10)) + + self.download_button = ttk.Button(download_frame, text="Download Selected", + command=self._download_selected) + self.download_button.pack(side=tk.RIGHT) + + def _update_filters(self): + """Update filter settings and re-sort results if needed.""" + # Ensure at least one filter is selected + if not any([self.filter_most_downloaded.get(), + self.filter_most_liked.get(), + self.filter_size.get()]): + self.filter_most_downloaded.set(True) + + def _perform_search(self): + """Perform the search based on current settings.""" + if not self.api: + messagebox.showerror("Error", "API not initialized") + return + + query = self.search_query.get().strip() + search_type = self.search_type.get() + + # Determine sort parameter + sort = "downloads" + if self.filter_most_liked.get() and not self.filter_most_downloaded.get(): + sort = "likes" + elif self.filter_size.get() and not self.filter_most_downloaded.get() and not self.filter_most_liked.get(): + sort = "lastModified" + + self.status_var.set(f"Searching {search_type.lower()}...") + self.search_button.config(state="disabled") + + # Clear previous results + for item in self.results_tree.get_children(): + self.results_tree.delete(item) + + # Perform search in thread + thread = threading.Thread(target=self._search_thread, + args=(query, search_type, sort)) + thread.daemon = True + thread.start() + + def _search_thread(self, query: str, search_type: str, sort: str): + """Thread function for performing search.""" + try: + if search_type == "Models": + results = self.api.search_models(query, limit=50, sort=sort) + else: + results = self.api.search_datasets(query, limit=50, sort=sort) + + self.current_results = results + + # Update UI in main thread + self.root.after(0, self._populate_results, results, search_type) + + except Exception as e: + self.root.after(0, lambda: messagebox.showerror("Search Error", str(e))) + finally: + self.root.after(0, lambda: self.search_button.config(state="normal")) + + def _populate_results(self, results: List[Dict], search_type: str): + """Populate the treeview with search results.""" + count = 0 + + for item in results: + try: + # Extract common fields + if search_type == "Models": + repo_id = item.get("modelId", item.get("id", "")) + pipeline_tag = item.get("pipeline_tag", "") + tags = item.get("tags", []) + keywords = ", ".join(tags[:3]) if tags else pipeline_tag + else: + repo_id = item.get("id", "") + task_ids = item.get("cardData", {}).get("task_ids", []) + keywords = ", ".join(task_ids[:3]) if task_ids else "dataset" + + creator = repo_id.split("/")[0] if "/" in repo_id else "" + name = repo_id.split("/")[1] if "/" in repo_id else repo_id + + # Get description + description = "" + if search_type == "Models": + description = item.get("description", "") + else: + card_data = item.get("cardData", {}) + description = card_data.get("description", card_data.get("summary", "")) + + # Truncate description + if len(description) > 100: + description = description[:97] + "..." + + # Calculate size + size_bytes = 0 + siblings = item.get("siblings", []) + for sibling in siblings: + if isinstance(sibling, dict): + size = sibling.get("size", 0) + if isinstance(size, (int, float)): + size_bytes += size + + size_str = self._format_size(size_bytes) if size_bytes > 0 else "-" + + # Get metadata + metadata_parts = [] + downloads = item.get("downloads", 0) + likes = item.get("likes", 0) + + if downloads > 0: + metadata_parts.append(f"↓{self._format_number(downloads)}") + if likes > 0: + metadata_parts.append(f"♥{self._format_number(likes)}") + + if search_type == "Models": + library = item.get("library_name", "") + if library: + metadata_parts.append(library) + + metadata = " | ".join(metadata_parts) + + # Insert into treeview + self.results_tree.insert("", tk.END, values=( + creator, name, description, keywords, size_str, metadata + )) + + count += 1 + + except Exception as e: + print(f"Error processing result: {e}") + continue + + self.status_var.set(f"Found {count} {search_type.lower()}") + + def _format_size(self, bytes_size: int) -> str: + """Format bytes to human readable size.""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + + def _format_number(self, num: int) -> str: + """Format large numbers with K, M suffixes.""" + if num >= 1_000_000: + return f"{num/1_000_000:.1f}M" + elif num >= 1_000: + return f"{num/1_000:.1f}K" + return str(num) + + def _on_item_double_click(self, event): + """Handle double-click on a result item.""" + self._download_selected() + + def _download_selected(self): + """Download the selected model or dataset.""" + selection = self.results_tree.selection() + if not selection: + messagebox.showinfo("No Selection", "Please select an item to download") + return + + item = self.results_tree.item(selection[0]) + values = item['values'] + + if len(values) < 2: + return + + creator = values[0] + name = values[1] + repo_id = f"{creator}/{name}" if creator else name + + # Ask for download location + download_dir = filedialog.askdirectory(title="Select Download Directory") + if not download_dir: + return + + # Create download window + self._show_download_window(repo_id, download_dir) + + def _show_download_window(self, repo_id: str, download_dir: str): + """Show a window for selecting files to download.""" + download_window = tk.Toplevel(self.root) + download_window.title(f"Download: {repo_id}") + download_window.geometry("800x500") + + # Get files list + ttk.Label(download_window, text="Fetching file list...").pack(pady=10) + + def fetch_files(): + files = self.api.get_model_files(repo_id) + download_window.after(0, lambda: self._populate_download_window( + download_window, repo_id, download_dir, files)) + + thread = threading.Thread(target=fetch_files) + thread.daemon = True + thread.start() + + def _populate_download_window(self, window: tk.Toplevel, repo_id: str, + download_dir: str, files: List[Dict]): + """Populate the download window with file list.""" + # Clear window + for widget in window.winfo_children(): + widget.destroy() + + ttk.Label(window, text=f"Select files to download from {repo_id}:").pack(pady=5) + + # File list frame + list_frame = ttk.Frame(window) + list_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create treeview for files + columns = ("filename", "size") + file_tree = ttk.Treeview(list_frame, columns=columns, show="tree headings", height=15) + file_tree.heading("#0", text="Select") + file_tree.heading("filename", text="File") + file_tree.heading("size", text="Size") + + file_tree.column("#0", width=50) + file_tree.column("filename", width=500) + file_tree.column("size", width=100) + + vsb = ttk.Scrollbar(list_frame, orient="vertical", command=file_tree.yview) + file_tree.configure(yscrollcommand=vsb.set) + + file_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + vsb.pack(side=tk.RIGHT, fill=tk.Y) + + # Add files to tree with checkboxes + file_vars = {} + for file_info in files: + filename = file_info.get("rfilename", "") + size = file_info.get("size", 0) + size_str = self._format_size(size) if size > 0 else "-" + + item_id = file_tree.insert("", tk.END, text="☐", + values=(filename, size_str)) + file_vars[item_id] = {"filename": filename, "selected": False} + + # Toggle selection on click + def toggle_selection(event): + item = file_tree.identify("item", event.x, event.y) + if item in file_vars: + file_vars[item]["selected"] = not file_vars[item]["selected"] + check = "☑" if file_vars[item]["selected"] else "☐" + file_tree.item(item, text=check) + + file_tree.bind("", toggle_selection) + + # Button frame + button_frame = ttk.Frame(window) + button_frame.pack(fill=tk.X, padx=10, pady=10) + + def select_all(): + for item_id in file_vars: + file_vars[item_id]["selected"] = True + file_tree.item(item_id, text="☑") + + def select_none(): + for item_id in file_vars: + file_vars[item_id]["selected"] = False + file_tree.item(item_id, text="☐") + + def select_gguf(): + for item_id in file_vars: + filename = file_vars[item_id]["filename"] + is_gguf = filename.lower().endswith(".gguf") + file_vars[item_id]["selected"] = is_gguf + file_tree.item(item_id, text="☑" if is_gguf else "☐") + + ttk.Button(button_frame, text="Select All", command=select_all).pack(side=tk.LEFT, padx=5) + ttk.Button(button_frame, text="Select None", command=select_none).pack(side=tk.LEFT, padx=5) + ttk.Button(button_frame, text="Select GGUF Only", command=select_gguf).pack(side=tk.LEFT, padx=5) + + def start_download(): + selected_files = [info["filename"] for info in file_vars.values() if info["selected"]] + if not selected_files: + messagebox.showinfo("No Selection", "Please select at least one file to download") + return + + window.destroy() + self._download_files(repo_id, selected_files, download_dir) + + ttk.Button(button_frame, text="Download Selected", + command=start_download).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="Cancel", + command=window.destroy).pack(side=tk.RIGHT, padx=5) + + def _download_files(self, repo_id: str, files: List[str], download_dir: str): + """Download selected files.""" + # Create progress window + progress_window = tk.Toplevel(self.root) + progress_window.title("Downloading...") + progress_window.geometry("500x200") + + ttk.Label(progress_window, text=f"Downloading from {repo_id}").pack(pady=10) + + current_file_var = tk.StringVar(value="Preparing...") + ttk.Label(progress_window, textvariable=current_file_var).pack(pady=5) + + progress_var = tk.DoubleVar() + progress_bar = ttk.Progressbar(progress_window, variable=progress_var, + maximum=100, length=400) + progress_bar.pack(pady=10) + + status_var = tk.StringVar(value="Starting download...") + ttk.Label(progress_window, textvariable=status_var).pack(pady=5) + + cancel_flag = {"cancelled": False} + + def cancel_download(): + cancel_flag["cancelled"] = True + progress_window.destroy() + + ttk.Button(progress_window, text="Cancel", command=cancel_download).pack(pady=10) + + def download_thread(): + total_files = len(files) + completed = 0 + + for filename in files: + if cancel_flag["cancelled"]: + break + + current_file_var.set(f"Downloading: {filename}") + save_path = os.path.join(download_dir, repo_id.replace("/", "_"), filename) + + def update_progress(percent, downloaded, total): + progress_var.set(percent) + size_str = f"{self._format_size(downloaded)} / {self._format_size(total)}" + status_var.set(f"File {completed + 1}/{total_files}: {size_str}") + + success = self.api.download_file(repo_id, filename, save_path, update_progress) + + if success: + completed += 1 + + if cancel_flag["cancelled"]: + break + + if not cancel_flag["cancelled"]: + progress_window.after(0, lambda: messagebox.showinfo( + "Download Complete", + f"Downloaded {completed}/{total_files} files to {download_dir}")) + + progress_window.after(0, progress_window.destroy) + + thread = threading.Thread(target=download_thread) + thread.daemon = True + thread.start() + + +def main(): + """Main entry point for the HuggingFace Downloader GUI.""" + root = tk.Tk() + app = HuggingFaceDownloaderGUI(root) + root.mainloop() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/install_mcp.py b/install_mcp.py new file mode 100644 index 0000000..20fd86e --- /dev/null +++ b/install_mcp.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Install MCP server dependencies + +This script installs the required packages for the MCP server functionality. +""" + +import subprocess +import sys +import os +from pathlib import Path + + +def install_mcp_dependencies(): + """Install MCP server dependencies.""" + print("Installing MCP server dependencies...") + + # Check if we're in a virtual environment + venv_path = None + if hasattr(sys, 'prefix') and sys.prefix != sys.base_prefix: + # We're in a virtual environment + if sys.platform == "win32": + venv_path = Path(sys.prefix) / "Scripts" / "pip.exe" + else: + venv_path = Path(sys.prefix) / "bin" / "pip" + + # Use appropriate pip command + if venv_path and venv_path.exists(): + pip_cmd = [str(venv_path)] + else: + pip_cmd = [sys.executable, "-m", "pip"] + + # Install packages + packages = [ + "mcp>=1.0.0", + "uvloop>=0.19.0; sys_platform != 'win32'" + ] + + try: + for package in packages: + print(f"Installing {package}...") + result = subprocess.run( + pip_cmd + ["install", package], + capture_output=True, + text=True, + check=True + ) + print(f"✓ {package} installed successfully") + + print("\n✅ All MCP dependencies installed successfully!") + print("\nTo use the MCP server:") + print("1. Configure it via Tools → MCP Server Config") + print("2. Run: python mcp_server.py") + print("3. Add the configuration to Claude Desktop") + + return True + + except subprocess.CalledProcessError as e: + print(f"❌ Error installing packages: {e}") + print(f"Command output: {e.stdout}") + print(f"Command error: {e.stderr}") + return False + except Exception as e: + print(f"❌ Unexpected error: {e}") + return False + + +def check_mcp_installation(): + """Check if MCP is already installed.""" + try: + import mcp + print(f"✓ MCP is already installed (version: {mcp.__version__})") + return True + except ImportError: + print("ℹ MCP is not installed") + return False + + +def main(): + """Main installation function.""" + print("=== LLM_Train MCP Server Setup ===\n") + + if check_mcp_installation(): + response = input("MCP is already installed. Reinstall? (y/N): ") + if response.lower() != 'y': + print("Setup cancelled.") + return + + success = install_mcp_dependencies() + + if success: + print("\n🎉 Setup complete! You can now use the MCP server functionality.") + else: + print("\n💥 Setup failed. Please check the error messages above.") + print("You may need to run this script as administrator or check your internet connection.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/launch_remote_control.bat b/launch_remote_control.bat new file mode 100644 index 0000000..7304517 --- /dev/null +++ b/launch_remote_control.bat @@ -0,0 +1,23 @@ +@echo off +:: Launch LLM_Train Remote Control Application + +echo Starting LLM_Train Remote Control... +echo. + +:: Check if Python is available +python --version >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Python is not installed or not in PATH. + echo Please install Python from https://python.org/downloads/ + pause + exit /b 1 +) + +:: Launch the remote control application +python remotecontrol.py + +if %errorlevel% neq 0 ( + echo. + echo Application exited with error code %errorlevel% + pause +) \ No newline at end of file diff --git a/llm_runtime/__init__.py b/llm_runtime/__init__.py new file mode 100644 index 0000000..97c88c3 --- /dev/null +++ b/llm_runtime/__init__.py @@ -0,0 +1,37 @@ +from .registry import load_model as _load_model +from .types import UnifiedModel, GenerateConfig + +# Announce the loaded model to the global spy so other parts of the app can access it +def load_model(*args, **kwargs): + """ + Proxy to the real registry.load_model that also announces the loaded model via __spy. + Accepts arbitrary args/kwargs to remain compatible with all loaders. + """ + model = _load_model(*args, **kwargs) + try: + import __spy as spy # local import to avoid hard dependency during tooling + # Best-effort extraction of a model name from common argument patterns + model_name = ( + kwargs.get("source") + or kwargs.get("model") + or (args[0] if args else None) + or getattr(model, "name", None) + or "unknown" + ) + + # Shallow capture of load parameters (omit non-serializable) + safe_params = {} + for k, v in kwargs.items(): + try: + repr(v) # ensure it is representable + safe_params[k] = v + except Exception: + continue + + spy.set_model(str(model_name), model, **safe_params) + except Exception: + # Never let announcing break model loading + pass + return model + +__all__ = ["load_model", "UnifiedModel", "GenerateConfig"] \ No newline at end of file diff --git a/llm_runtime/chat_session.py b/llm_runtime/chat_session.py new file mode 100644 index 0000000..ab931ea --- /dev/null +++ b/llm_runtime/chat_session.py @@ -0,0 +1,268 @@ +""" +Chat Session Management for KV Caching + +This module provides chat session management with persistent KV cache +for efficient multi-turn conversations across different model formats. +""" + +from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +import threading + + +@dataclass +class ChatSessionConfig: + """Configuration for chat sessions""" + max_context_length: int = 4096 + cache_warmup: bool = True # Pre-fill system prompt + streaming: bool = True + + +class ChatSession: + """ + Manages persistent KV cache state for multi-turn conversations. + + Supports different model formats: + - GGUF models: Rely on llama-cpp-python's built-in caching + - Transformers models: Manual KV cache management with past_key_values + """ + + def __init__(self, session_id: str, config: ChatSessionConfig = None): + self.session_id = session_id + self.config = config or ChatSessionConfig() + self.lock = threading.RLock() + + # KV cache state (transformers models) + self.past_key_values: Optional[Tuple] = None + self.cached_input_ids: Optional[List[int]] = None + self.context_length: int = 0 + + # Conversation history + self.messages: List[Dict[str, str]] = [] + self.system_prompt: Optional[str] = None + + # State tracking + self.is_prefilled: bool = False + self.last_prompt_hash: Optional[str] = None + self.last_prompt: Optional[str] = None + + def add_message(self, role: str, content: str) -> None: + """Add a message to the conversation history""" + with self.lock: + if role == "system" and not self.messages: + self.system_prompt = content + self.messages.append({"role": role, "content": content}) + + def clear_messages(self) -> None: + """Clear conversation history but preserve system prompt""" + with self.lock: + if self.system_prompt: + self.messages = [{"role": "system", "content": self.system_prompt}] + else: + self.messages = [] + self.invalidate_cache() + + def invalidate_cache(self) -> None: + """Invalidate KV cache - forces re-prefill on next generation""" + with self.lock: + self.past_key_values = None + self.cached_input_ids = None + self.context_length = 0 + self.is_prefilled = False + self.last_prompt_hash = None + self.last_prompt = None + + def should_invalidate(self, new_prompt: str, tokenizer=None) -> bool: + """Check if cache should be invalidated based on prompt changes with proper token-level validation""" + import hashlib + + current_hash = hashlib.md5(new_prompt.encode()).hexdigest() + + # Always invalidate if no cache exists + if self.past_key_values is None: + return True + + # Always invalidate if we don't have a tokenizer for proper validation + if tokenizer is None: + return True + + # Always invalidate if no previous prompt exists + if not hasattr(self, 'last_prompt') or self.last_prompt is None: + return True + + # Check for exact token-level prefix match + if self._has_valid_token_prefix(new_prompt, self.last_prompt, tokenizer): + # Additional safety checks for turn boundaries + if self._violates_turn_boundaries(new_prompt, self.last_prompt): + print("[KV_CACHE] Invalidating cache: turn boundary violation detected") + return True + return False + + # Invalidate if no valid prefix match + print("[KV_CACHE] Invalidating cache: no valid token prefix match") + return True + + def _has_valid_token_prefix(self, new_prompt: str, old_prompt: str, tokenizer) -> bool: + """Check if new prompt has exact token-level prefix match with cached prompt""" + try: + # Tokenize both prompts + old_tokens = tokenizer.encode(old_prompt, add_special_tokens=False) + new_tokens = tokenizer.encode(new_prompt, add_special_tokens=False) + + # New prompt must be longer than or equal to old + if len(new_tokens) < len(old_tokens): + return False + + # Check exact token-by-token match for the prefix + for i, (old_token, new_token) in enumerate(zip(old_tokens, new_tokens)): + if old_token != new_token: + print(f"[KV_CACHE] Token mismatch at position {i}: old={old_token}, new={new_token}") + return False + + return True + except Exception as e: + print(f"[KV_CACHE] Error in token prefix validation: {e}") + return False + + def _violates_turn_boundaries(self, new_prompt: str, old_prompt: str) -> bool: + """Check if the prompt violates turn boundaries (conversation integrity)""" + # Look for end-of-turn markers that indicate conversation corruption + eot_markers = ["<|eot_id|>", "<|end_of_turn|>", "", "<|endoftext|>"] + + # If the old prompt ended with an EOT marker, we should start fresh + for marker in eot_markers: + if old_prompt.rstrip().endswith(marker): + # Check if new prompt is a proper continuation (should start with User: or similar) + continuation_part = new_prompt[len(old_prompt):].lstrip() + if not (continuation_part.startswith("User:") or continuation_part.startswith("Human:") or continuation_part.startswith("\nUser:") or continuation_part.startswith("\nHuman:")): + return True + + # Check for conversation format corruption (duplicate roles, malformed structure) + if self._has_conversation_corruption(new_prompt): + return True + + return False + + def _has_conversation_corruption(self, prompt: str) -> bool: + """Detect conversation format corruption that indicates cache should be invalidated""" + # Look for signs of corrupted conversation format + corruption_patterns = [ + "Assistant: Assistant:", # Duplicate assistant labels + "User: User:", # Duplicate user labels + "Assistant: User", # Role confusion + "User: Assistant:", # Role confusion + "of of of", # Repetitive token generation (sign of corruption) + "151 of 151", # Specific corruption pattern we observed + ] + + for pattern in corruption_patterns: + if pattern in prompt: + print(f"[KV_CACHE] Detected conversation corruption: '{pattern}'") + return True + + return False + + def update_cache(self, past_key_values: Tuple, input_ids: List[int], prompt: str) -> None: + """Update the KV cache state with turn boundary validation""" + import hashlib + + with self.lock: + # Validate that we're ending on a complete turn boundary + if not self._is_complete_turn(prompt): + print("[KV_CACHE] Warning: Caching incomplete turn - this may cause issues") + + self.past_key_values = past_key_values + self.cached_input_ids = input_ids + self.context_length = len(input_ids) + self.is_prefilled = True + self.last_prompt_hash = hashlib.md5(prompt.encode()).hexdigest() + self.last_prompt = prompt # Store the actual prompt for comparison + + def _is_complete_turn(self, prompt: str) -> bool: + """Check if the prompt ends on a complete turn boundary""" + # Look for proper turn endings + prompt_trimmed = prompt.rstrip() + + # Should end with Assistant response, not mid-generation + valid_endings = [ + "", + "<|eot_id|>", + "<|end_of_turn|>", + "<|endoftext|>", + ] + + # Or should end with a complete sentence/response + if any(prompt_trimmed.endswith(ending) for ending in valid_endings): + return True + + # For non-marked conversations, check if it looks like a complete response + # (ends with punctuation and doesn't look cut off) + if prompt_trimmed.endswith(('.', '!', '?', ':', ';')): + return True + + # If it ends with "Assistant:" it's ready for generation + if prompt_trimmed.endswith("Assistant:"): + return True + + return False + + def get_cache_info(self) -> Dict[str, Any]: + """Get information about current cache state""" + with self.lock: + return { + "session_id": self.session_id, + "has_cache": self.past_key_values is not None, + "context_length": self.context_length, + "is_prefilled": self.is_prefilled, + "message_count": len(self.messages), + "max_context": self.config.max_context_length + } + + +class ChatSessionManager: + """Manages multiple chat sessions""" + + def __init__(self): + self.sessions: Dict[str, ChatSession] = {} + self.lock = threading.RLock() + self.default_session_id = "default" + + def get_session(self, session_id: str = None, config: ChatSessionConfig = None) -> ChatSession: + """Get or create a chat session""" + if session_id is None: + session_id = self.default_session_id + + with self.lock: + if session_id not in self.sessions: + self.sessions[session_id] = ChatSession(session_id, config) + return self.sessions[session_id] + + def clear_session(self, session_id: str = None) -> None: + """Clear a specific session""" + if session_id is None: + session_id = self.default_session_id + + with self.lock: + if session_id in self.sessions: + self.sessions[session_id].invalidate_cache() + self.sessions[session_id].clear_messages() + + def remove_session(self, session_id: str) -> None: + """Remove a session completely""" + with self.lock: + if session_id in self.sessions: + del self.sessions[session_id] + + def get_all_sessions(self) -> List[str]: + """Get list of all session IDs""" + with self.lock: + return list(self.sessions.keys()) + + +# Global session manager instance +_session_manager = ChatSessionManager() + + +def get_session_manager() -> ChatSessionManager: + """Get the global chat session manager""" + return _session_manager \ No newline at end of file diff --git a/llm_runtime/device_utils.py b/llm_runtime/device_utils.py new file mode 100644 index 0000000..e081fc8 --- /dev/null +++ b/llm_runtime/device_utils.py @@ -0,0 +1,99 @@ +import torch +from typing import Union, Literal + +Backend = Literal["hf", "gptq"] +DevIn = Union[None, str, int] +DevOut = Union[str, int] + +def _has_mps() -> bool: + return getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available() + +def _first_cuda_index() -> int | None: + return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None + +def normalize_device(dev: DevIn = None, *, backend: Backend = "hf") -> DevOut: + """ + Normalize a user/device string/int into what each backend expects. + + Inputs accepted: + None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int) + + Returns: + backend == "hf": "cpu" | "mps" | "cuda:N" + backend == "gptq": "cpu" | "mps" | "disk" | N (int) + """ + # 1) Auto/default + if dev in (None, "auto"): + cuda0 = _first_cuda_index() + if cuda0 is not None: + return (cuda0 if backend == "gptq" else f"cuda:{cuda0}") + if _has_mps(): + return "mps" + # AutoGPTQ can also run with 'disk' offload if caller wants; default to CPU here + return "cpu" + + # 2) Explicit CPU/MPS/DISK + if isinstance(dev, str) and dev.lower() in {"cpu", "mps", "disk"}: + # HF does not know "disk"; treat as CPU for HF branch + return dev if backend == "gptq" or dev != "disk" else "cpu" + + # 3) Explicit CUDA string + if isinstance(dev, str) and dev.lower().startswith("cuda"): + # Accept "cuda" and "cuda:N" + if dev == "cuda": + idx = _first_cuda_index() + if idx is None: + # No CUDA available; degrade to CPU/MPS appropriately + return "cpu" if backend == "hf" else "cpu" + return (idx if backend == "gptq" else f"cuda:{idx}") + # cuda:N + try: + idx = int(dev.split(":", 1)[1]) + except (IndexError, ValueError): + raise ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.") + return (idx if backend == "gptq" else f"cuda:{idx}") + + # 4) Integer GPU index + if isinstance(dev, int): + if dev < 0: + raise ValueError(f"GPU index must be >= 0, got {dev}") + return (dev if backend == "gptq" else f"cuda:{dev}") + + raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}") + +# --- Convenience wrappers ------------------------------------------------------- + +def device_for_hf(dev: DevIn = None) -> str: + """Return a device string suitable for HuggingFace (e.g., 'cuda:0', 'cpu', 'mps').""" + out = normalize_device(dev, backend="hf") + assert isinstance(out, str) + return out + +def device_for_gptq(dev: DevIn = None) -> Union[int, str]: + """Return an int GPU index or 'cpu'/'mps'/'disk' for AutoGPTQ.""" + out = normalize_device(dev, backend="gptq") + assert isinstance(out, (int, str)) + return out + +def debug_device_placement(model, name="model"): + """Debug helper to check where model parameters are placed""" + try: + devices = set() + for name_param, param in model.named_parameters(): + devices.add(str(param.device)) + print(f"[DEBUG] {name} parameters on devices: {devices}") + + # Check first parameter device + first_param = next(model.parameters()) + print(f"[DEBUG] {name} primary device: {first_param.device}") + + return first_param.device + except Exception as e: + print(f"[DEBUG] Could not check {name} device placement: {e}") + return None + +# --- Minimal self-test (run this file directly) --------------------------------- +if __name__ == "__main__": + tests = [None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1] + for t in tests: + print(f"in={t!r:7} -> hf={device_for_hf(t)!r:7} gptq={device_for_gptq(t)!r}") \ No newline at end of file diff --git a/llm_runtime/loader_factory.py b/llm_runtime/loader_factory.py new file mode 100644 index 0000000..facf415 --- /dev/null +++ b/llm_runtime/loader_factory.py @@ -0,0 +1,25 @@ +# loader_factory.py +from typing import Any, Tuple +from .model_router import detect_loader_type, LoaderKind + +def select_loader(source: str) -> Tuple[LoaderKind, str]: + return detect_loader_type(source) + +def load_model_for_gui(source: str, **kwargs: Any): + kind, reason = detect_loader_type(source) + print(f"[ROUTER] Using {kind.upper()} loader: {reason}") + print(f"[ROUTER] Source: {source}") + print(f"[ROUTER] Kwargs: {kwargs}") + + if kind == "hf": + from .loaders.transformers_loader import HFTransformersLoader + loader = HFTransformersLoader() + print(f"[ROUTER] Using HF loader for: {source}") + else: # "gguf" + from .loaders.llamacpp_loader import LlamaCppLoader + loader = LlamaCppLoader() + print(f"[ROUTER] Using GGUF loader for: {source}") + + # Load the model + model = loader.load(source, **kwargs) + return model, kind, reason \ No newline at end of file diff --git a/llm_runtime/loaders/autogptq_loader.py b/llm_runtime/loaders/autogptq_loader.py new file mode 100644 index 0000000..4b922a4 --- /dev/null +++ b/llm_runtime/loaders/autogptq_loader.py @@ -0,0 +1,176 @@ +from typing import Any, Iterator, List, Optional +import os, json +import torch +from llm_runtime.types import UnifiedModel, GenerateConfig +from llm_runtime.device_utils import device_for_gptq + +def _inputs_device_from_gptq_device(dev) -> torch.device: + """ + AutoGPTQ wants: int GPU index | 'cpu' | 'mps' | 'disk' + But tokenizer tensors need a torch.device: + - int -> 'cuda:{idx}' + - 'cpu' -> 'cpu' + - 'mps' -> 'mps' + - 'disk' -> still run the forward on CUDA/CPU; safest default: 'cpu' + """ + if isinstance(dev, int): + return torch.device(f"cuda:{dev}") + if isinstance(dev, str): + if dev in ("cpu", "mps"): + return torch.device(dev) + if dev == "disk": + # inputs live on CPU; model will page as needed + return torch.device("cpu") + # Fallback + return torch.device("cpu") + +class _GPTQUnified: + def __init__(self, src: str, **kwargs: Any): + print(f"[GPTQ_DEBUG] _GPTQUnified.__init__() called with src='{src}', kwargs={kwargs}") + try: + from auto_gptq import AutoGPTQForCausalLM + from transformers import AutoTokenizer + print("[GPTQ_DEBUG] Successfully imported auto_gptq and transformers") + except ImportError as e: + print(f"[GPTQ_DEBUG] Failed to import auto_gptq or transformers: {e}") + raise ImportError("auto-gptq and transformers are required for GPTQ models. Install with: pip install auto-gptq transformers") + + print(f"[GPTQ_DEBUG] Loading GPTQ model from: {src}") + + # Normalize device for AutoGPTQ (int or 'cpu'/'mps'/'disk') + raw_device = kwargs.get("device") + print(f"[GPTQ_DEBUG] Raw device from kwargs: {raw_device}") + self._gptq_dev = device_for_gptq(raw_device) + print(f"[GPTQ_DEBUG] Normalized device for GPTQ: {self._gptq_dev}") + self._inputs_device = _inputs_device_from_gptq_device(self._gptq_dev) + print(f"[GPTQ_DEBUG] Using device (gptq): {self._gptq_dev} | inputs will go to: {self._inputs_device}") + + trust_remote = kwargs.get("trust_remote_code", True) + token = kwargs.get("token") + + # Tokenizer + self.tok = AutoTokenizer.from_pretrained( + src, + use_fast=True, + trust_remote_code=trust_remote, + token=token + ) + + # Model + self.model = AutoGPTQForCausalLM.from_quantized( + src, + device=self._gptq_dev, # int or 'cpu'/'mps'/'disk' + trust_remote_code=trust_remote, + use_safetensors=True, + use_triton=kwargs.get("use_triton", False), + token=token + ) + + # Pad token safety + if self.tok.pad_token is None: + self.tok.pad_token = self.tok.eos_token + + def _build_eos_ids(self, stop) -> Optional[List[int]]: + """Encode stop strings to token IDs (take the first token of each stop string).""" + if not stop: + return None + out: List[int] = [] + for s in stop: + if not s: + continue + ids = self.tok.encode(s, add_special_tokens=False) + if ids: + out.append(ids[0]) + return out or None + + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + from transformers import StoppingCriteria, StoppingCriteriaList + + enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device) + + class MultiStringStop(StoppingCriteria): + def __init__(self, toks, stops): + self.toks, self.stops = toks, stops or [] + def __call__(self, input_ids, scores, **_): + # Simple but effective; for high perf, implement a token-level matcher. + text = self.toks.decode(input_ids[0], skip_special_tokens=True) + return any(s in text for s in self.stops) + + do_sample = cfg.temperature is not None and cfg.temperature > 0.0 + gen_kwargs = dict( + max_new_tokens=cfg.max_tokens, + do_sample=do_sample, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + pad_token_id=self.tok.pad_token_id, + ) + if do_sample: + gen_kwargs["temperature"] = float(cfg.temperature) + + if cfg.stop: + gen_kwargs["stopping_criteria"] = StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) + + out_ids = self.model.generate(**enc, **gen_kwargs) + + # Decode only new tokens + new_tokens = out_ids[0][enc.input_ids.shape[1]:] + return self.tok.decode(new_tokens, skip_special_tokens=True) + + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + import threading + from transformers import TextIteratorStreamer + + enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device) + streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True) + + do_sample = cfg.temperature is not None and cfg.temperature > 0.0 + gen_kwargs = dict( + max_new_tokens=cfg.max_tokens, + do_sample=do_sample, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + streamer=streamer, + pad_token_id=self.tok.pad_token_id, + ) + if do_sample: + gen_kwargs["temperature"] = float(cfg.temperature) + + def _worker(): + self.model.generate(**enc, **gen_kwargs) + + t = threading.Thread(target=_worker, daemon=True) + t.start() + for chunk in streamer: + yield chunk + + def tokenize(self, text: str) -> List[int]: + return self.tok.encode(text, add_special_tokens=False) + + def detokenize(self, ids: List[int]) -> str: + return self.tok.decode(ids, skip_special_tokens=True) + +class AutoGPTQLoader: + name = "gptq" + + def can_load(self, source: str, **kwargs: Any) -> bool: + # Local folder? + if os.path.isdir(source): + # quantize_config.json is a strong GPTQ signal + qc = os.path.join(source, "quantize_config.json") + if os.path.exists(qc): + return True + # Fallback: peek at config.json + try: + with open(os.path.join(source, "config.json"), "r", encoding="utf-8") as f: + cfg = json.load(f) + text = json.dumps(cfg).lower() + return ("gptq" in text) or ("quantize_config" in text) or ("quant_method" in text) + except Exception: + pass + # HF repo path (heuristic): let the loader try + elif "/" in source and not os.path.exists(source): + return True + return False + + def load(self, source: str, **kwargs: Any) -> UnifiedModel: + return _GPTQUnified(source, **kwargs) \ No newline at end of file diff --git a/llm_runtime/loaders/awq_loader.py b/llm_runtime/loaders/awq_loader.py new file mode 100644 index 0000000..19d3505 --- /dev/null +++ b/llm_runtime/loaders/awq_loader.py @@ -0,0 +1,132 @@ +from typing import Any, Iterator, List, Optional +import os +from llm_runtime.types import UnifiedModel, GenerateConfig + +class _AWQUnified: + def __init__(self, src: str, **kwargs: Any): + try: + from autoawq import AutoAWQForCausalLM + from transformers import AutoTokenizer + except ImportError: + raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers") + + print(f"Loading AWQ model from: {src}") + + # Load tokenizer + self.tok = AutoTokenizer.from_pretrained( + src, + use_fast=True, + trust_remote_code=kwargs.get("trust_remote_code", True) + ) + + # Load AWQ model + self.model = AutoAWQForCausalLM.from_quantized( + src, + device_map=kwargs.get("device_map", "auto"), + trust_remote_code=kwargs.get("trust_remote_code", True), + safetensors=True, + ) + + # Set pad token if not present + if self.tok.pad_token is None: + self.tok.pad_token = self.tok.eos_token + + def _build_eos_ids(self, stop) -> Optional[List[int]]: + """Convert stop strings to token IDs""" + if not stop: + return None + + ids = [] + for s in stop: + if len(s) == 1: + tid = self.tok.convert_tokens_to_ids(s) + if tid is not None: + ids.append(tid) + return ids or None + + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + from transformers import StoppingCriteria, StoppingCriteriaList + + inputs = self.tok(prompt, return_tensors="pt").to(self.model.device) + + class MultiStringStop(StoppingCriteria): + def __init__(self, toks, stops): + self.toks, self.stops = toks, stops or [] + + def __call__(self, input_ids, scores, **_): + text = self.toks.decode(input_ids[0], skip_special_tokens=True) + return any(s in text for s in self.stops) + + out_ids = self.model.generate( + **inputs, + max_new_tokens=cfg.max_tokens, + do_sample=cfg.temperature > 0.0, + temperature=cfg.temperature if cfg.temperature > 0.0 else None, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None, + pad_token_id=self.tok.pad_token_id, + ) + + # Decode only the new tokens + generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return generated_text + + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + import threading + from transformers import TextIteratorStreamer + + enc = self.tok(prompt, return_tensors="pt").to(self.model.device) + streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True) + + def _worker(): + self.model.generate( + **enc, + max_new_tokens=cfg.max_tokens, + do_sample=cfg.temperature > 0.0, + temperature=cfg.temperature if cfg.temperature > 0.0 else None, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + streamer=streamer, + pad_token_id=self.tok.pad_token_id, + ) + + t = threading.Thread(target=_worker) + t.start() + + for text in streamer: + yield text + + def tokenize(self, text: str) -> List[int]: + return self.tok.encode(text, add_special_tokens=False) + + def detokenize(self, ids: List[int]) -> str: + return self.tok.decode(ids, skip_special_tokens=True) + +class AWQLoader: + name = "awq" + + def can_load(self, source: str, **kwargs: Any) -> bool: + # Check for AWQ indicators + if os.path.isdir(source): + # Look for AWQ indicators in config.json + try: + import json + config_path = os.path.join(source, "config.json") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = json.load(f) + config_str = str(config).lower() + return ("awq" in config_str or + "quantization_config" in config and + "awq" in str(config.get("quantization_config", {})).lower()) + except: + pass + elif "/" in source and not os.path.exists(source): # HF repo + # For HF repos, we'll let it try if it looks like AWQ + return "awq" in source.lower() + + return False + + def load(self, source: str, **kwargs: Any) -> UnifiedModel: + return _AWQUnified(source, **kwargs) \ No newline at end of file diff --git a/llm_runtime/loaders/exl2_loader.py b/llm_runtime/loaders/exl2_loader.py new file mode 100644 index 0000000..1a013ac --- /dev/null +++ b/llm_runtime/loaders/exl2_loader.py @@ -0,0 +1,137 @@ +from typing import Any, Iterator, List, Optional +import os +from llm_runtime.types import UnifiedModel, GenerateConfig + +class _ExLlama2Unified: + def __init__(self, src: str, **kwargs: Any): + try: + from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError: + raise ImportError("exllamav2 is required for EXL2 models. Install with: pip install exllamav2") + + print(f"Loading EXL2 model from: {src}") + + # Configure model + self.config = ExLlamaV2Config(src) + + # Apply any config overrides + if "max_seq_len" in kwargs: + self.config.max_seq_len = kwargs["max_seq_len"] + if "scale_pos_emb" in kwargs: + self.config.scale_pos_emb = kwargs["scale_pos_emb"] + if "scale_alpha_value" in kwargs: + self.config.scale_alpha_value = kwargs["scale_alpha_value"] + + # Initialize model + self.model = ExLlamaV2(self.config) + + # Load model weights + self.model.load() + + # Initialize tokenizer + self.tokenizer = ExLlamaV2Tokenizer(self.config) + + # Initialize cache + self.cache = ExLlamaV2Cache(self.model, lazy=kwargs.get("lazy_cache", True)) + + # Initialize generator for streaming + self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer) + + print(f"EXL2 model loaded successfully") + + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + # Create sampler settings + settings = ExLlamaV2Sampler.Settings() + settings.temperature = cfg.temperature + settings.top_p = cfg.top_p + + # Set stop conditions + if cfg.stop: + # ExLlamaV2 expects stop strings as a list + stop_conditions = list(cfg.stop) + else: + stop_conditions = [] + + # Generate text + output = self.generator.generate_simple( + prompt=prompt, + max_new_tokens=cfg.max_tokens, + seed=kwargs.get("seed", -1), + token_healing=kwargs.get("token_healing", True), + temperature=cfg.temperature, + top_p=cfg.top_p, + stop_conditions=stop_conditions, + ) + + return output + + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + # Create sampler settings + settings = ExLlamaV2Sampler.Settings() + settings.temperature = cfg.temperature + settings.top_p = cfg.top_p + + # Set stop conditions + if cfg.stop: + stop_conditions = list(cfg.stop) + else: + stop_conditions = [] + + # Begin streaming generation + input_ids = self.tokenizer.encode(prompt) + self.generator.begin_stream( + input_ids=input_ids, + gen_settings=settings, + token_healing=kwargs.get("token_healing", True), + seed=kwargs.get("seed", -1), + ) + + generated_tokens = 0 + + while generated_tokens < cfg.max_tokens: + chunk, eos, tokens = self.generator.stream() + + if chunk: + yield chunk + + generated_tokens += tokens + + # Check for stop conditions + if eos: + break + + if cfg.stop: + # Check if any stop condition is met in the generated text so far + current_text = self.generator.sequence_str + if any(stop in current_text for stop in cfg.stop): + break + + def tokenize(self, text: str) -> List[int]: + return self.tokenizer.encode(text).tolist() + + def detokenize(self, ids: List[int]) -> str: + import torch + tensor_ids = torch.tensor([ids], dtype=torch.long) + return self.tokenizer.decode(tensor_ids)[0] + +class ExLlama2Loader: + name = "exllama2" + + def can_load(self, source: str, **kwargs: Any) -> bool: + # Check for EXL2 model directory structure + if os.path.isdir(source): + # Look for config.json and .safetensors files that indicate EXL2 + config_path = os.path.join(source, "config.json") + if os.path.exists(config_path): + # Check if there are .safetensors files with EXL2 naming pattern + for file in os.listdir(source): + if file.endswith(".safetensors") and ("model" in file.lower() or "exl2" in file.lower()): + return True + elif source.lower().endswith(".exl2"): + return True + + return False + + def load(self, source: str, **kwargs: Any) -> UnifiedModel: + return _ExLlama2Unified(source, **kwargs) \ No newline at end of file diff --git a/llm_runtime/loaders/llamacpp_loader.py b/llm_runtime/loaders/llamacpp_loader.py new file mode 100644 index 0000000..c5e3dd9 --- /dev/null +++ b/llm_runtime/loaders/llamacpp_loader.py @@ -0,0 +1,91 @@ +from typing import Any, Iterator, List +from llm_runtime.types import UnifiedModel, GenerateConfig + +class _LlamaCppUnified: + def __init__(self, model_path: str, **kwargs: Any): + # Import llama_cpp directly instead of from main.py to avoid circular imports + from llama_cpp import Llama + + if not model_path.lower().endswith(".gguf"): + raise ValueError(f"Not a valid GGUF model: {model_path}") + + self.model_path = model_path + self.kwargs = kwargs + self._llama = None + + def _get_model(self): + """Lazy load the model using existing implementation""" + if self._llama is None: + # Import the _get_llama function from main module to maintain compatibility + try: + import sys + import os + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + from main import _get_llama + self._llama = _get_llama( + self.model_path, + n_ctx=self.kwargs.get("n_ctx", 8192), + n_gpu_layers=self.kwargs.get("n_gpu_layers", 0), + lora_path=self.kwargs.get("lora_path"), + n_threads=self.kwargs.get("n_threads"), + ) + except ImportError: + # Fallback to direct llama-cpp-python if main import fails + from llama_cpp import Llama + self._llama = Llama( + model_path=self.model_path, + n_ctx=self.kwargs.get("n_ctx", 8192), + n_gpu_layers=self.kwargs.get("n_gpu_layers", 0), + verbose=self.kwargs.get("verbose", False), + n_threads=self.kwargs.get("n_threads") + ) + return self._llama + + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + llama = self._get_model() + + # Convert GenerateConfig to llama-cpp-python format + result = llama( + prompt, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature, + top_p=cfg.top_p, + stop=list(cfg.stop) if cfg.stop else None, + echo=False + ) + + return result["choices"][0]["text"] + + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + llama = self._get_model() + + # Use streaming generation + for chunk in llama.create_completion( + prompt, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature, + top_p=cfg.top_p, + stop=list(cfg.stop) if cfg.stop else None, + stream=True, + echo=False + ): + text = chunk["choices"][0]["text"] + if text: + yield text + + def tokenize(self, text: str) -> List[int]: + llama = self._get_model() + return llama.tokenize(text.encode("utf-8"), add_bos=False) + + def detokenize(self, ids: List[int]) -> str: + llama = self._get_model() + return llama.detokenize(ids).decode("utf-8", errors="ignore") + +class LlamaCppLoader: + name = "llamacpp" + + def can_load(self, source: str, **kwargs: Any) -> bool: + return source.lower().endswith(".gguf") + + def load(self, source: str, **kwargs: Any) -> UnifiedModel: + return _LlamaCppUnified(source, **kwargs) \ No newline at end of file diff --git a/llm_runtime/loaders/transformers_loader.py b/llm_runtime/loaders/transformers_loader.py new file mode 100644 index 0000000..401a851 --- /dev/null +++ b/llm_runtime/loaders/transformers_loader.py @@ -0,0 +1,672 @@ +from typing import Any, Iterator, List, Optional, Tuple, Dict +import os +import time +from llm_runtime.types import UnifiedModel, GenerateConfig +from llm_runtime.util_chat import apply_chat_template +from llm_runtime.chat_session import ChatSession, get_session_manager +from llm_runtime.device_utils import device_for_hf + +class _HFUnified: + def __init__(self, src: str, **kwargs: Any): + print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={kwargs}") + try: + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + print("[HF_DEBUG] Successfully imported torch and transformers") + except ImportError: + print("[HF_DEBUG] Failed to import torch or transformers") + raise ImportError("transformers, torch, and accelerate are required for HF models. Install with: pip install transformers torch accelerate safetensors") + + self.torch = torch + self.TextIteratorStreamer = TextIteratorStreamer + + # Normalize device for HuggingFace + self.device = device_for_hf(kwargs.get("device")) + print(f"Loading HF model from: {src}") + print(f"Using device: {self.device}") + + # Load tokenizer + self.tok = AutoTokenizer.from_pretrained( + src, + use_fast=True, + trust_remote_code=kwargs.get("trust_remote_code", True), + token=kwargs.get("token") + ) + + # Prepare quantization config if requested + quantization_config = self._prepare_quantization_config(kwargs) + + # Prepare device mapping + device_map, max_memory = self._prepare_device_mapping(kwargs, self.device) + + # Load model with advanced options + load_kwargs = { + "torch_dtype": kwargs.get("torch_dtype", "auto"), + "device_map": device_map, + "trust_remote_code": kwargs.get("trust_remote_code", True), + "low_cpu_mem_usage": True, + "token": kwargs.get("token") + } + + if quantization_config: + load_kwargs["quantization_config"] = quantization_config + print(f"Using quantization: {kwargs.get('quantization', 'none')}") + # When quantization is enabled, force device_map to "auto" to avoid device format conflicts + load_kwargs["device_map"] = "auto" + print(f"[DEBUG] Forcing device_map='auto' for quantization compatibility") + + if max_memory: + load_kwargs["max_memory"] = max_memory + print(f"Memory limits: {max_memory}") + + if kwargs.get("offload_folder"): + load_kwargs["offload_folder"] = kwargs.get("offload_folder") + print(f"Offloading to: {kwargs.get('offload_folder')}") + + self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs) + + # Set pad token if not present + if self.tok.pad_token is None: + self.tok.pad_token = self.tok.eos_token + + # Initialize chat session support with context from kwargs + self.session_manager = get_session_manager() + self.current_session = None + + # Store context size for session creation + self.n_ctx = kwargs.get('n_ctx', 4096) + + def _prepare_quantization_config(self, kwargs: Any): + """Prepare quantization configuration""" + quantization = kwargs.get("quantization", "none") + + if quantization == "none": + return None + + try: + from transformers import BitsAndBytesConfig + except ImportError: + print("Warning: BitsAndBytesConfig not available, skipping quantization") + return None + + if quantization == "4bit": + return BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=self.torch.bfloat16, + ) + elif quantization == "8bit": + return BitsAndBytesConfig( + load_in_8bit=True, + ) + else: + print(f"Warning: Unknown quantization type '{quantization}', skipping") + return None + + def _prepare_device_mapping(self, kwargs: Any, hf_device: str): + """Prepare device mapping and memory limits""" + device_strategy = kwargs.get("device_strategy", "auto") + gpu_memory_limit = kwargs.get("gpu_memory_limit", None) + + device_map = "auto" + max_memory = None + + if device_strategy == "force_gpu": + device_map = {"": hf_device} + elif device_strategy == "balanced_split": + # Use memory limits for balanced CPU/GPU split + if gpu_memory_limit: + # For quantization, use integer device index; for non-quantization, use string + if kwargs.get("quantization", "none") != "none": + gpu_device = 0 if hf_device.startswith("cuda:") else 0 + else: + gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0" + max_memory = { + gpu_device: f"{gpu_memory_limit}GiB", + "cpu": "48GiB", # Large CPU limit + } + device_map = "auto" + elif device_strategy == "cpu_only": + device_map = {"": "cpu"} + else: # auto + if gpu_memory_limit: + # For quantization, use integer device index; for non-quantization, use string + if kwargs.get("quantization", "none") != "none": + gpu_device = 0 if hf_device.startswith("cuda:") else 0 + else: + gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0" + max_memory = { + gpu_device: f"{gpu_memory_limit}GiB", + "cpu": "32GiB", + } + device_map = "auto" + + return device_map, max_memory + + def _build_eos_ids(self, stop) -> Optional[List[int]]: + """Convert stop strings to token IDs""" + if not stop: + return None + + ids = [] + for s in stop: + # Handle single character stops + if len(s) == 1: + tid = self.tok.convert_tokens_to_ids(s) + if tid is not None: + ids.append(tid) + return ids or None + + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + """Generate text - uses KV cache if session_id provided, otherwise fallback to standard generation""" + # Set defaults for None values + if cfg.temperature is None: + cfg.temperature = 0.7 # Default temperature + if cfg.top_p is None: + cfg.top_p = 0.95 # Default top_p + if cfg.max_tokens is None: + cfg.max_tokens = 500 # Default max tokens + + session_id = kwargs.get('session_id', 'default') # Use 'default' session if none specified + if session_id: + return self.generate_with_cache(prompt, session_id, cfg, **kwargs) + + # Fallback to standard generation without KV cache + from transformers import StoppingCriteria, StoppingCriteriaList + + inputs = self.tok(prompt, return_tensors="pt").to(self.model.device) + + class MultiStringStop(StoppingCriteria): + def __init__(self, toks, stops): + self.toks, self.stops = toks, stops or [] + + def __call__(self, input_ids, scores, **_): + text = self.toks.decode(input_ids[0], skip_special_tokens=True) + return any(s in text for s in self.stops) + + with self.torch.no_grad(): + out_ids = self.model.generate( + **inputs, + max_new_tokens=cfg.max_tokens, + do_sample=cfg.temperature is not None and cfg.temperature > 0.0, + temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None, + pad_token_id=self.tok.pad_token_id, + ) + + # Decode only the new tokens + generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return generated_text + + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + """Stream text generation - uses KV cache if session_id provided, otherwise fallback to standard streaming""" + session_id = kwargs.pop('session_id', 'default') # Remove from kwargs to avoid duplicate + if session_id: + yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs) + return + + # Fallback to standard streaming without KV cache + import threading + + enc = self.tok(prompt, return_tensors="pt").to(self.model.device) + streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True) + + def _worker(): + with self.torch.no_grad(): + self.model.generate( + **enc, + max_new_tokens=cfg.max_tokens, + do_sample=cfg.temperature is not None and cfg.temperature > 0.0, + temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None, + top_p=cfg.top_p, + eos_token_id=self._build_eos_ids(cfg.stop), + streamer=streamer, + pad_token_id=self.tok.pad_token_id, + ) + + t = threading.Thread(target=_worker) + t.start() + + for text in streamer: + yield text + + def tokenize(self, text: str) -> List[int]: + return self.tok.encode(text, add_special_tokens=False) + + def detokenize(self, ids: List[int]) -> str: + return self.tok.decode(ids, skip_special_tokens=True) + + def get_session(self, session_id: str = "default") -> ChatSession: + """Get or create a chat session for KV caching""" + # Create session config with the correct context size + from llm_runtime.chat_session import ChatSessionConfig + session_config = ChatSessionConfig(max_context_length=self.n_ctx) + return self.session_manager.get_session(session_id, config=session_config) + + def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', + cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]: + """Prefill phase: process the full prompt and return logits + past_key_values""" + with self.torch.no_grad(): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=True, + return_dict=True + ) + return outputs.logits, outputs.past_key_values + + def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', + past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]: + """Incremental prefill: process only new tokens with existing KV cache""" + with self.torch.no_grad(): + outputs = self.model( + input_ids=new_input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + return_dict=True + ) + return outputs.logits, outputs.past_key_values + + def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', + past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]: + """Single decode step: generate next token with KV cache""" + with self.torch.no_grad(): + outputs = self.model( + input_ids=input_ids, # Should be shape [1, 1] for single token + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + return_dict=True + ) + return outputs.logits, outputs.past_key_values + + def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int: + """Sample next token from logits based on generation config with optimized sampling""" + # Get logits for the last token + next_token_logits = logits[0, -1, :] + + # Apply temperature scaling + if cfg.temperature is not None and cfg.temperature != 1.0 and cfg.temperature > 0: + next_token_logits = next_token_logits / cfg.temperature + + # Apply top-p (nucleus) sampling if specified + if cfg.temperature is not None and cfg.temperature > 0.0 and cfg.top_p is not None and cfg.top_p < 1.0: + sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True) + cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > cfg.top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) + next_token_logits[indices_to_remove] = float('-inf') + + if cfg.temperature is not None and cfg.temperature > 0.0: + # Sample from distribution + probs = self.torch.softmax(next_token_logits, dim=-1) + next_token = self.torch.multinomial(probs, num_samples=1) + else: + # Greedy sampling (deterministic) + next_token = self.torch.argmax(next_token_logits, dim=-1, keepdim=True) + + return next_token.item() + + def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool: + """Check if generation should stop""" + # Check for EOS token + if token_id == self.tok.eos_token_id: + return True + + # Check for custom stop strings + if cfg.stop: + for stop_str in cfg.stop: + if stop_str in generated_text: + return True + + return False + + def generate_with_cache(self, prompt: str, session_id: str = "default", + cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: + """Generate text with persistent KV cache using manual generation loop""" + # Set defaults for None values + if cfg.temperature is None: + cfg.temperature = 0.7 # Default temperature + if cfg.top_p is None: + cfg.top_p = 0.95 # Default top_p + if cfg.max_tokens is None: + cfg.max_tokens = 500 # Default max tokens + + session = self.get_session(session_id) + + # Tokenize input + inputs = self.tok(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(self.model.device) + attention_mask = inputs.attention_mask.to(self.model.device) + + generated_tokens = [] + max_new_tokens = cfg.max_tokens + + # Check if we need to invalidate cache with proper token validation + if session.should_invalidate(prompt, self.tok): + session.invalidate_cache() + print(f"[KV_CACHE] Cache invalidated for session {session_id}") + + # Determine if we need prefill phase + if session.past_key_values is None: + print(f"[KV_CACHE] Running prefill phase for {len(input_ids[0])} tokens") + # Prefill phase: process full prompt + logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg) + + # Update session cache + session.update_cache(past_key_values, input_ids[0].tolist(), prompt) + + # Sample first token + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + + # Update input_ids and attention_mask for decode phase + current_length = input_ids.shape[1] + else: + print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}") + # Use cached state + past_key_values = session.past_key_values + current_length = session.context_length + + # For cached state, we still need to process the new part of the prompt if it exists + cached_length = len(session.cached_input_ids) + if input_ids.shape[1] > cached_length: + print(f"[KV_CACHE] Processing {input_ids.shape[1] - cached_length} new tokens (incremental prefill)") + new_tokens = input_ids[:, cached_length:] + + # Process only the new tokens with existing KV cache + # Create attention mask that covers both cached and new tokens + total_length = cached_length + new_tokens.shape[1] + extended_attention = self.torch.ones((1, total_length), device=self.model.device) + + logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg) + session.update_cache(past_key_values, input_ids[0].tolist(), prompt) + current_length = input_ids.shape[1] + else: + # No new tokens to process - get initial logits for generation + # Use a forward pass with the last token to get proper logits distribution + last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device) + total_length = current_length + 1 + next_attention = self.torch.ones((1, total_length), device=self.model.device) + + logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg) + + # Sample first new token + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + + # Decode phase: generate tokens one by one with optimized attention management + for step in range(max_new_tokens - 1): # -1 because we already generated first token + # Check stop conditions + generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True) + if self._should_stop(next_token_id, generated_text, cfg): + break + + # Prepare inputs for next step - only pass the new token + next_input = self.torch.tensor([[next_token_id]], device=self.model.device) + total_length = current_length + len(generated_tokens) + 1 + next_attention = self.torch.ones((1, total_length), device=self.model.device) + + # Generate next token using cached KV states + logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg) + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + + # Update session with final state - include all processed tokens + final_input_ids = input_ids[0].tolist() + generated_tokens + session.update_cache(past_key_values, final_input_ids, prompt + self.tok.decode(generated_tokens, skip_special_tokens=True)) + + # Decode generated tokens + generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True) + print(f"[KV_CACHE] Generated {len(generated_tokens)} tokens with KV cache") + + return generated_text + + def stream_with_cache(self, prompt: str, session_id: str = "default", + cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: + """Stream generation with persistent KV cache""" + session = self.get_session(session_id) + + # Tokenize input + inputs = self.tok(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(self.model.device) + attention_mask = inputs.attention_mask.to(self.model.device) + + generated_tokens = [] + max_new_tokens = cfg.max_tokens + + # Check if we need to invalidate cache with proper token validation + if session.should_invalidate(prompt, self.tok): + session.invalidate_cache() + print(f"[KV_CACHE] Cache invalidated for session {session_id}") + + # Determine if we need prefill phase + if session.past_key_values is None: + print(f"[KV_CACHE] Streaming prefill phase for {len(input_ids[0])} tokens") + # Prefill phase: process full prompt + logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg) + + # Update session cache + session.update_cache(past_key_values, input_ids[0].tolist(), prompt) + + # Sample first token + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + current_length = input_ids.shape[1] + + # Yield first token + first_text = self.tok.decode([next_token_id], skip_special_tokens=True) + if first_text: + yield first_text + else: + print(f"[KV_CACHE] Streaming with cached KV state, context length: {session.context_length}") + # Use cached state - optimized logic similar to generate_with_cache + past_key_values = session.past_key_values + current_length = session.context_length + + # Handle new tokens in prompt with incremental prefill + cached_length = len(session.cached_input_ids) + if input_ids.shape[1] > cached_length: + print(f"[KV_CACHE] Streaming incremental prefill for {input_ids.shape[1] - cached_length} new tokens") + new_tokens = input_ids[:, cached_length:] + + # Process only new tokens with existing KV cache + total_length = cached_length + new_tokens.shape[1] + extended_attention = self.torch.ones((1, total_length), device=self.model.device) + + logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg) + session.update_cache(past_key_values, input_ids[0].tolist(), prompt) + current_length = input_ids.shape[1] + else: + # No new tokens - use last cached token for initial generation + last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device) + total_length = current_length + 1 + next_attention = self.torch.ones((1, total_length), device=self.model.device) + + logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg) + + # Sample first token + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + + # Yield first token + first_text = self.tok.decode([next_token_id], skip_special_tokens=True) + if first_text: + yield first_text + + # Decode phase: stream tokens one by one with optimized KV cache usage + for step in range(max_new_tokens - 1): + # Check stop conditions + generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True) + if self._should_stop(next_token_id, generated_text, cfg): + break + + # Prepare inputs for next step - efficient single token processing + next_input = self.torch.tensor([[next_token_id]], device=self.model.device) + total_length = current_length + len(generated_tokens) + 1 + next_attention = self.torch.ones((1, total_length), device=self.model.device) + + # Generate next token using cached KV states + logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg) + next_token_id = self._sample_token(logits, cfg) + generated_tokens.append(next_token_id) + + # Yield the new token immediately + token_text = self.tok.decode([next_token_id], skip_special_tokens=True) + if token_text: + yield token_text + + # Update session with final state - include all processed tokens + final_input_ids = input_ids[0].tolist() + generated_tokens + full_generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True) + session.update_cache(past_key_values, final_input_ids, prompt + full_generated_text) + + print(f"[KV_CACHE] Streamed {len(generated_tokens)} tokens with persistent KV cache") + + def clear_session_cache(self, session_id: str = "default") -> None: + """Clear KV cache for a specific session""" + session = self.get_session(session_id) + session.invalidate_cache() + print(f"[DEBUG] Cleared cache for session {session_id}") + + def get_session_info(self, session_id: str = "default") -> Dict[str, Any]: + """Get information about a chat session""" + session = self.get_session(session_id) + return session.get_cache_info() + + def add_conversation_turn(self, user_message: str, assistant_message: str, + session_id: str = "default") -> None: + """Add a complete conversation turn to the session history""" + session = self.get_session(session_id) + session.add_message("user", user_message) + session.add_message("assistant", assistant_message) + + def get_model_info(self) -> Dict[str, Any]: + """Get comprehensive information about the loaded model""" + try: + # Basic model info + info = { + "model_name": getattr(self.model.config, 'name_or_path', 'unknown'), + "model_type": self.model.config.model_type, + "vocab_size": self.model.config.vocab_size, + "device": str(self.device), + "dtype": str(self.model.dtype), + "supports_kv_cache": True, + "max_position_embeddings": getattr(self.model.config, 'max_position_embeddings', 'unknown'), + "torch_compile_enabled": hasattr(self.model, '_orig_mod') + } + + # Memory info + if self.torch.cuda.is_available() and str(self.device) != 'cpu': + info.update({ + "gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / 1024**3:.2f} GB", + "gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / 1024**3:.2f} GB", + "gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB" + }) + + # Model parameters + total_params = sum(p.numel() for p in self.model.parameters()) + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + info.update({ + "total_parameters": f"{total_params:,}", + "trainable_parameters": f"{trainable_params:,}", + "model_size_mb": f"{total_params * 4 / 1024**2:.2f} MB" # Assuming float32 + }) + + return info + except Exception as e: + return {"error": f"Could not get model info: {e}", "supports_kv_cache": True} + + def get_kv_cache_stats(self) -> Dict[str, Any]: + """Get KV cache statistics across all sessions""" + try: + sessions = self.session_manager.get_all_sessions() + stats = { + "total_sessions": len(sessions), + "active_sessions": 0, + "total_cached_tokens": 0, + "memory_usage_estimate": "0 MB" + } + + for session_id in sessions: + session = self.session_manager.get_session(session_id) + if session.past_key_values is not None: + stats["active_sessions"] += 1 + stats["total_cached_tokens"] += session.context_length + + # Rough estimate of KV cache memory usage + # Each token in KV cache roughly uses: hidden_size * num_layers * 2 (key + value) * 4 bytes (float32) + if stats["total_cached_tokens"] > 0: + try: + hidden_size = self.model.config.hidden_size + num_layers = self.model.config.num_hidden_layers + memory_bytes = stats["total_cached_tokens"] * hidden_size * num_layers * 2 * 4 + stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB" + except: + pass + + return stats + except Exception as e: + return {"error": f"Could not get KV cache stats: {e}"} + + def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, float]: + """Warm up the model and measure performance metrics""" + try: + print("[KV_CACHE] Warming up model...") + start_time = time.time() + + # Simple generation to warm up CUDA kernels + cfg = GenerateConfig(max_tokens=5, temperature=0.0) + _ = self.generate(test_prompt, cfg=cfg, session_id="warmup") + + warmup_time = time.time() - start_time + + # Clean up warmup session + self.clear_session_cache("warmup") + + return { + "warmup_time": warmup_time, + "status": "success" + } + except Exception as e: + return { + "warmup_time": 0.0, + "status": "failed", + "error": str(e) + } + +class HFTransformersLoader: + name = "hf" + + def can_load(self, source: str, **kwargs: Any) -> bool: + # Accept HF repo-id or local dir with config.json (covers .safetensors) + if "/" in source and not os.path.exists(source): # repo-id like "microsoft/DialoGPT-medium" + return True + + # Check if it's a directory with config.json + if os.path.isdir(source) and os.path.exists(os.path.join(source, "config.json")): + return True + + # If it's a single .safetensors file, look for config.json in the same directory + if source.lower().endswith('.safetensors'): + parent_dir = os.path.dirname(source) + return os.path.exists(os.path.join(parent_dir, "config.json")) + + # Also support .bin files (PyTorch checkpoints) + if source.lower().endswith('.bin'): + parent_dir = os.path.dirname(source) + return os.path.exists(os.path.join(parent_dir, "config.json")) + + return False + + def load(self, source: str, **kwargs: Any) -> UnifiedModel: + return _HFUnified(source, **kwargs) \ No newline at end of file diff --git a/llm_runtime/model_router.py b/llm_runtime/model_router.py new file mode 100644 index 0000000..7fcfca0 --- /dev/null +++ b/llm_runtime/model_router.py @@ -0,0 +1,55 @@ +from __future__ import annotations +import os, json +from pathlib import Path +from typing import Literal, Optional, Dict, Tuple, Callable + +# Only support Hugging Face (Safetensors) and GGUF +LoaderKind = Literal["hf", "gguf"] + +def detect_loader_type(source: str) -> Tuple[LoaderKind, str]: + """ + Decide which loader to use based on the path/repo. + Returns: (kind, reason) + kind ∈ {"hf","gguf"} + """ + print(f"[ROUTER_DEBUG] detect_loader_type() called with source: '{source}'") + p = Path(source) + low = source.lower() + + # 0) Force HF for official Meta repos + if not p.exists() and low.startswith("meta-llama/"): + result = ("hf", "Official Meta repo (full-precision).") + print(f"[ROUTER_DEBUG] Meta repo detected -> {result}") + return result + + # 1) Local FILE + if p.exists() and p.is_file(): + if p.suffix.lower() == ".gguf": + return "gguf", "Local .gguf file." + return "hf", "Local non-.gguf file (default HF)." + + # 2) Local DIR + if p.exists() and p.is_dir(): + # GGUF hint: any *.gguf inside + if any(p.glob("*.gguf")): + return "gguf", "Directory contains .gguf file(s)." + # Inspect config.json only for GGUF hints; all else defaults to HF + cfg = p / "config.json" + if cfg.exists(): + try: + data = json.loads(cfg.read_text(encoding="utf-8")) + text = json.dumps(data).lower() + if ("gguf" in text) or ("llama.cpp" in text) or ("ggml" in text): + return "gguf", "config.json mentions GGUF/llama.cpp." + except Exception: + pass + return "hf", "Default HF for non-GGUF directories." + + # 3) Remote repo style (org/name) + if ("/" in source or source.count("\\") == 1) and not p.exists(): + if any(tag in low for tag in ["gguf", "ggml", "llama.cpp"]): + return "gguf", "Repo name suggests GGUF." + return "hf", "Remote repo (default HF)." + + # 4) Fallback + return "hf", "Fallback to HF." \ No newline at end of file diff --git a/llm_runtime/registry.py b/llm_runtime/registry.py new file mode 100644 index 0000000..e40c2a2 --- /dev/null +++ b/llm_runtime/registry.py @@ -0,0 +1,10 @@ +from typing import Any +from .types import UnifiedModel +from .loader_factory import load_model_for_gui + +def load_model(source: str, **kwargs: Any) -> UnifiedModel: + """Load model using the router-based factory system""" + print(f"[REGISTRY_DEBUG] load_model() called with source='{source}', kwargs={kwargs}") + model, kind, reason = load_model_for_gui(source, **kwargs) + print(f"[REGISTRY_DEBUG] load_model_for_gui() returned: kind='{kind}', reason='{reason}'") + return model \ No newline at end of file diff --git a/llm_runtime/types.py b/llm_runtime/types.py new file mode 100644 index 0000000..49b484d --- /dev/null +++ b/llm_runtime/types.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Iterable, Iterator, List, Optional, Protocol, Any + +@dataclass +class GenerateConfig: + # None means "use model defaults" + max_tokens: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + stop: Optional[Iterable[str]] = None + # Optional context window hint; models may ignore if they manage it internally + context_window: Optional[int] = None + # Optional advanced knobs; models may ignore if unsupported + min_p: Optional[float] = None + typical_p: Optional[float] = None + repetition_penalty: Optional[float] = None + +class UnifiedModel(Protocol): + # Default config uses None for all tunables so the model can choose + def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: ... + def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: ... + def tokenize(self, text: str) -> List[int]: ... + def detokenize(self, ids: List[int]) -> str: ... \ No newline at end of file diff --git a/llm_runtime/util_chat.py b/llm_runtime/util_chat.py new file mode 100644 index 0000000..da77072 --- /dev/null +++ b/llm_runtime/util_chat.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, List, Optional + +def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str: + """Apply chat template to messages using tokenizer or fallback to ChatML format""" + if hasattr(tokenizer, "apply_chat_template"): + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt) + + # Fallback: simple ChatML-ish format + parts = [] + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") + if role == "system": + parts.append(f"<|system|>\n{content}\n") + elif role == "assistant": + parts.append(f"<|assistant|>\n{content}\n") + else: + parts.append(f"<|user|>\n{content}\n") + + if add_generation_prompt: + parts.append("<|assistant|>\n") + + return "".join(parts) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..5424cf1 --- /dev/null +++ b/main.py @@ -0,0 +1,3093 @@ +import argparse +import sys +import threading +import time +import tkinter as tk +import traceback +import linecache +from tkinter import ttk, filedialog, messagebox +from typing import List, Dict, Any, Optional +import queue +import subprocess +import os +from dotenv import load_dotenv +import requests +from settings_manager import SettingsManager, open_settings_dialog +from grouped_download_manager import GroupedDownloadManager, FileSelectionDialog +from grouped_download_gui import GroupedDownloadManagerTab +from model_library import ModelLibraryTab +from mcp_config import open_mcp_config +from splash_screen import SplashManager +from mcp_tab import MCPTab +from model_converter import ModelConverterTab +from chess_tab import ChessTab +from chat_templates import get_template_manager, ChatTemplateDialog + +# from finetune_tab import FineTuneTab # Temporarily disabled for debugging + +# Try to import torch for GPU functionality +try: + import torch +except ImportError: + torch = None + +# Load from your custom env file +load_dotenv("HUGGINGFACE.env") + +# Global execution tracer +_trace_enabled = False +_trace_filters = ['llm_runtime', 'main.py', 'autogptq', 'transformers'] + +def execution_tracer(frame, event, arg): + """Trace every line of code execution during model loading""" + if not _trace_enabled: + return + + if event == 'line': + filename = frame.f_code.co_filename + lineno = frame.f_lineno + + # Only trace files we care about + if any(filter_term in filename for filter_term in _trace_filters): + try: + line = linecache.getline(filename, lineno).strip() + short_filename = filename.split('/')[-1] if '/' in filename else filename.split('\\')[-1] + print(f"[TRACE] {short_filename}:{lineno} | {line}") + except: + pass + + return execution_tracer + +def start_tracing(): + """Start execution tracing""" + global _trace_enabled + _trace_enabled = True + print("[TRACE] Execution tracing STARTED") + +def stop_tracing(): + """Stop execution tracing""" + global _trace_enabled + _trace_enabled = False + print("[TRACE] Execution tracing STOPPED") + +# Access the key +hf_key = os.getenv("HF_API_KEY") + +# Embedded llama.cpp (self-contained, no external daemons) +try: + from llama_cpp import Llama + + +except Exception: + print( + "The 'llama-cpp-python' package is required. Please install dependencies with: pip install -r requirements.txt", + file=sys.stderr) + raise + +# ChessGPT support +try: + import torch + import transformers + from transformers import AutoTokenizer, AutoModelForCausalLM + + TRANSFORMERS_AVAILABLE = True + MIN_TRANSFORMERS_VERSION = '4.25.1' + if transformers.__version__ < MIN_TRANSFORMERS_VERSION: + print( + f"Warning: transformers version {transformers.__version__} may not be compatible. Recommended: {MIN_TRANSFORMERS_VERSION}+") +except ImportError: + TRANSFORMERS_AVAILABLE = False + print("Warning: transformers not available. Chess mode will not work. Install with: pip install torch transformers") + +_llama_cache = { + "key": None, # (model_path, lora_path, n_ctx, n_gpu_layers, n_threads) + "llm": None, +} + +# ChessGPT model cache +_chessgpt_cache = { + "tokenizer": None, + "model": None, + "loaded": False +} + +# Serialize llama native calls when sharing one instance across threads +_LLAMA_LOCK = threading.RLock() + +# Default system prompt and anti-echo stop tokens +DEFAULT_SYSTEM_PROMPT = ( + "You are a helpful, concise assistant. Answer the user's question directly. " + "Do not repeat or paraphrase the user's prompt; provide only your answer." +) +# Conservative defaults to reduce echo +DEFAULT_TEMPERATURE = 0.2 +DEFAULT_MAX_TOKENS = 256 +# Add common model chat-template markers to stop tokens to avoid template echoes like [INST] <> ... +STOP_TOKENS = [ + "\nUser:", "\nYou:", "User:", "You:", + "<|user|>", "<|assistant|>", "<|eot_id|>", "<|eom_id|>", + "[INST]", "[/INST]", "<>", "", "<>", ">" +] + + +def _strip_echo_from_response(text: str, last_user_prompt: Optional[str]) -> str: + try: + s = text or "" + # Remove the last user prompt if the model echoed it at the start + if last_user_prompt: + lu = (last_user_prompt or "").strip() + if lu and s.strip().startswith(lu): + # Cut the first occurrence + idx = s.find(lu) + if idx == 0: + s = s[len(lu):] + # Remove common stop tokens/templates that may leak + for tok in STOP_TOKENS: + s = s.replace(tok, " ") + # Clean up repeated whitespace + s = " ".join(s.split()) + return s.strip() + except Exception: + return text or "" + + +# Provide safe fallbacks so static analysis / early references don't error. +# If real implementations are defined later they will overwrite these. +try: + _extract_gguf_metadata # type: ignore +except NameError: + def _extract_gguf_metadata(path: str, key: str) -> Optional[str]: + # Non-invasive fallback: best-effort no-op that returns None + return None + +try: + _extract_gguf_int_metadata # type: ignore +except NameError: + def _extract_gguf_int_metadata(path: str, key: str) -> Optional[int]: + # Non-invasive fallback: best-effort no-op that returns None + return None + + +def _is_gguf_model(path: str) -> bool: + if not path: + return False + p = path.strip().strip('"') + return os.path.isfile(p) and p.lower().endswith('.gguf') + + +def _is_valid_model(path: str) -> bool: + """Check if the given path is a valid model file of any supported format.""" + if not path: + return False + + p = path.strip().strip('"') + + # Check if it might be a HuggingFace repo ID + if "/" in p and not os.path.exists(p): + return True # Let the loader validate it + + if not os.path.exists(p): + return False + + # Check for supported file extensions + p_lower = p.lower() + + # GGUF format (llama.cpp) + if p_lower.endswith('.gguf'): + return True + + # SafeTensors format (HuggingFace) + if p_lower.endswith('.safetensors'): + return True + + # PyTorch format + if p_lower.endswith('.bin') or p_lower.endswith('.pt') or p_lower.endswith('.pth'): + return True + + # GPTQ quantized models + if 'gptq' in p_lower and (p_lower.endswith('.safetensors') or p_lower.endswith('.bin')): + return True + + # AWQ quantized models + if 'awq' in p_lower and (p_lower.endswith('.safetensors') or p_lower.endswith('.bin')): + return True + + # EXL2 format + if p_lower.endswith('.exl2'): + return True + + # Check if it's a directory with model files + if os.path.isdir(p): + # Check for standard HuggingFace structure + config_path = os.path.join(p, "config.json") + if os.path.exists(config_path): + return True + + # Check for GPTQ models + if any(f for f in os.listdir(p) if 'gptq' in f.lower() and (f.endswith('.safetensors') or f.endswith('.bin'))): + return True + + # Check for AWQ models + if any(f for f in os.listdir(p) if 'awq' in f.lower() and (f.endswith('.safetensors') or f.endswith('.bin'))): + return True + + # Check for any model files + if any(f for f in os.listdir(p) if f.endswith(('.gguf', '.safetensors', '.bin', '.pt', '.pth', '.exl2'))): + return True + + return False + + +def _get_llama(model_path: str, n_ctx: int = 4096, n_gpu_layers: int = 0, lora_path: Optional[str] = None, + n_threads: Optional[int] = None) -> "Llama": + mp = model_path.strip().strip('"') + if n_threads is None or n_threads <= 0: + n_threads = max(1, os.cpu_count() or 1) + + # include file modification time so cache invalidates when model file changes + try: + mtime = os.path.getmtime(mp) + except Exception: + mtime = None + + # Normalize inputs to ints for cache key + try: + n_ctx_int = int(n_ctx) if n_ctx is not None else 0 + except Exception: + n_ctx_int = 0 + try: + n_gpu_int = int(n_gpu_layers) if n_gpu_layers is not None else 0 + except Exception: + n_gpu_int = 0 + + # include n_ctx and n_gpu_layers in key so different contexts create separate instances + key = (mp, lora_path, n_ctx_int, n_gpu_int, int(n_threads), mtime) + if _llama_cache["llm"] is not None and _llama_cache["key"] == key: + return _llama_cache["llm"] + + # Pass the requested context size and gpu layers to Llama so it uses the correct capacity. + # If the caller passes 0 for n_ctx, the underlying library will use the model's trained n_ctx. + print(f"[GGUF_DEBUG] Loading GGUF model with n_ctx={n_ctx_int}, n_gpu_layers={n_gpu_int}") + llm = Llama( + model_path=mp, + n_ctx=n_ctx_int, + n_gpu_layers=n_gpu_int, + lora_path=lora_path, + n_threads=n_threads, + verbose=False, + ) + print(f"[GGUF_DEBUG] GGUF model loaded successfully with GPU layers: {n_gpu_int}") + _llama_cache["key"] = key + _llama_cache["llm"] = llm + return llm + + +def _get_chessgpt(): + """Get the already-loaded ChessGPT GGUF model.""" + # The ChessGPT GGUF is already loaded as the main model + # We don't need to download anything - just return a flag + # The actual model is accessed through _get_llama() + return None, None # Return None since we're using the GGUF version + + +def _run_chessgpt_prompt(prompt: str, model_path: str = None, on_chunk: Optional[callable] = None, + max_tokens: int = 128) -> str: + """Run a prompt through ChessGPT GGUF model using the ChessGPT conversation format.""" + try: + # Format prompt for ChessGPT conversation style + # Add explicit instruction to return only the move + chess_prompt = f"A friendly, helpful chat between some humans.<|endoftext|>Human 0: {prompt}\nRespond with ONLY the chess move in UCI format (like e2e4).<|endoftext|>Human 1:" + + # Use the already-loaded GGUF model with proper context size + # ChessGPT was trained on 2048 context + llm = _get_llama(model_path, n_ctx=2048, n_gpu_layers=32) + + # Generate response using the GGUF model with lower temperature for more deterministic moves + with _LLAMA_LOCK: + response = llm( + chess_prompt, + max_tokens=20, # Reduced - we only need a move + temperature=0.3, # Lower temperature for more deterministic chess moves + top_p=0.9, + top_k=40, + echo=False, + stop=["<|endoftext|>", "Human 0:", "Human 1:", "\n"] + ) + + output_str = response['choices'][0]['text'].strip() + + print(f"[CHESS DEBUG] ChessGPT raw response: '{output_str}'") + + # Stream output if callback provided + if on_chunk: + for char in output_str: + on_chunk(char) + + return output_str + + except Exception as e: + print(f"ChessGPT GGUF generation failed: {e}") + raise + + +def run_prompt(model_path: str, prompt: str, stream: bool, n_ctx: int = 4096, n_gpu_layers: int = 0, + lora_path: Optional[str] = None, on_chunk: Optional[callable] = None, n_threads: Optional[int] = None, + max_tokens: Optional[int] = None, history: Optional[List[Dict[str, Any]]] = None, + cancel_event: Optional[threading.Event] = None, chess_mode: bool = False) -> str: + # Use ChessGPT GGUF if chess mode is enabled + if chess_mode: + try: + return _run_chessgpt_prompt(prompt, model_path=model_path, on_chunk=on_chunk, max_tokens=max_tokens or 128) + except Exception as e: + print(f"ChessGPT GGUF failed, falling back to regular model: {e}") + # Fall through to regular model with limited context + n_ctx = min(n_ctx, 2048) # Limit context for chess to avoid overflow + + llm = _get_llama(model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers, lora_path=lora_path, n_threads=n_threads) + # Build messages from prior history plus the new user message + messages: List[Dict[str, Any]] = [] + if not history or (history and history[0].get("role") != "system"): + messages.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT}) + if history: + messages.extend(history) + messages.append({"role": "user", "content": prompt}) + max_new_tokens = DEFAULT_MAX_TOKENS if (max_tokens is None or int(max_tokens) <= 0) else int(max_tokens) + out_parts: List[str] = [] + if stream: + with _LLAMA_LOCK: + for part in llm.create_chat_completion(messages=messages, stream=True, stop=STOP_TOKENS, + temperature=DEFAULT_TEMPERATURE, repeat_penalty=1.2, + max_tokens=max_new_tokens): + if cancel_event is not None and cancel_event.is_set(): + break + try: + chunk = part["choices"][0]["delta"].get("content", "") + except Exception: + chunk = part.get("choices", [{}])[0].get("message", {}).get("content", "") + if chunk: + out_parts.append(chunk) + if on_chunk: + on_chunk(chunk) + if cancel_event is not None and cancel_event.is_set(): + break + result = "".join(out_parts) + # Try to remove prompt echo/templates using the last user prompt + last_user = prompt + return _strip_echo_from_response(result, last_user) + else: + with _LLAMA_LOCK: + res = llm.create_chat_completion(messages=messages, stop=STOP_TOKENS, temperature=DEFAULT_TEMPERATURE, + repeat_penalty=1.2, max_tokens=max_new_tokens) + raw = res.get("choices", [{}])[0].get("message", {}).get("content", "") or "" + # Remove common echoes: use the prompt as the last user message + return _strip_echo_from_response(raw, prompt) + + +def chat_stream(model_path: str, messages: List[Dict[str, Any]], n_ctx: int = 4096, n_gpu_layers: int = 0, + lora_path: Optional[str] = None, on_chunk: Optional[callable] = None, n_threads: Optional[int] = None, + max_tokens: Optional[int] = None, cancel_event: Optional[threading.Event] = None, + chess_mode: bool = False, chat_template: Optional[str] = None, session_id: Optional[str] = None) -> str: + # Use ChessGPT if chess mode is enabled + if chess_mode: + try: + # Extract the last user message for ChessGPT + last_user_message = "" + for msg in reversed(messages): + if msg.get("role") == "user": + last_user_message = msg.get("content", "") + break + if last_user_message: + return _run_chessgpt_prompt(last_user_message, on_chunk=on_chunk, max_tokens=max_tokens or 128) + except Exception as e: + print(f"ChessGPT failed, falling back to regular model: {e}") + # Fall through to regular model + + # Ensure a system message exists at the start + if not messages or messages[0].get("role") != "system": + messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + list(messages) + collected: List[str] = [] + max_new_tokens = DEFAULT_MAX_TOKENS if (max_tokens is None or int(max_tokens) <= 0) else int(max_tokens) + + # Use appropriate loader based on model type + if _is_gguf_model(model_path): + # Use existing GGUF loader + llm = _get_llama(model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers, lora_path=lora_path, n_threads=n_threads) + with _LLAMA_LOCK: + for part in llm.create_chat_completion(messages=messages, stream=True, stop=STOP_TOKENS, + temperature=DEFAULT_TEMPERATURE, repeat_penalty=1.2, + max_tokens=max_new_tokens): + if cancel_event is not None and cancel_event.is_set(): + break + try: + delta = part["choices"][0]["delta"].get("content", "") + except Exception: + delta = part.get("choices", [{}])[0].get("message", {}).get("content", "") + if delta: + collected.append(delta) + if on_chunk: + on_chunk(delta) + if cancel_event is not None and cancel_event.is_set(): + break + else: + # Use unified loader for other model types - reuse cached model + from llm_runtime import GenerateConfig + + # Check if model is already loaded in cache (from _on_load_model) + if not hasattr(chat_stream, '_unified_model_cache') or chat_stream._unified_model_cache is None: + raise RuntimeError("Model not loaded. Please load a model first using the 'Load Model' button.") + + llm = chat_stream._unified_model_cache + print("DEBUG: Using cached unified model") + + # Convert messages to prompt format using chat template + if chat_template and chat_template != "None": + # Use the selected chat template + from chat_templates import get_template_manager + template_manager = get_template_manager() + prompt = template_manager.format_conversation(chat_template, messages, add_generation_prompt=True) + print(f"DEBUG: Using chat template '{chat_template}'") + else: + # Fallback to simple format for backward compatibility + prompt_parts = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "system": + prompt_parts.append(f"System: {content}") + elif role == "user": + prompt_parts.append(f"User: {content}") + elif role == "assistant": + prompt_parts.append(f"Assistant: {content}") + prompt = "\n".join(prompt_parts) + "\nAssistant:" + print("DEBUG: Using fallback User:/Assistant: format") + + print(f"DEBUG: Generated prompt: '{prompt}'") + + # Get appropriate stop tokens from template or use defaults + if chat_template and chat_template != "None": + from chat_templates import get_template_manager + template_manager = get_template_manager() + template_stop_tokens = template_manager.get_stop_tokens(chat_template) + stop_tokens = template_stop_tokens if template_stop_tokens else STOP_TOKENS + else: + stop_tokens = STOP_TOKENS + + # Generate with unified API and KV caching + cfg = GenerateConfig(max_tokens=max_new_tokens, temperature=DEFAULT_TEMPERATURE, top_p=0.9, stop=stop_tokens) + # Use instance-specific session ID to maintain conversation continuity while preventing cross-chat contamination + if session_id is None: + session_id = "default" # Fallback for CLI usage + + # Get session info before generation + if hasattr(llm, 'get_session_info'): + session_info = llm.get_session_info(session_id) + print(f"[KV_CACHE] Pre-generation session info: {session_info}") + + print(f"[CHAT] Starting streaming generation with KV caching enabled") + generation_start = time.time() + token_count = 0 + + for delta in llm.stream(prompt, cfg=cfg, session_id=session_id): + if cancel_event is not None and cancel_event.is_set(): + break + if delta: + collected.append(delta) + token_count += 1 + if on_chunk: + on_chunk(delta) + if cancel_event is not None and cancel_event.is_set(): + break + + generation_time = time.time() - generation_start + print(f"[KV_CACHE] Generated {token_count} tokens in {generation_time:.2f}s ({token_count/generation_time:.1f} tokens/s)") + + # Get session info after generation + if hasattr(llm, 'get_session_info'): + session_info = llm.get_session_info(session_id) + print(f"[KV_CACHE] Post-generation session info: {session_info}") + + # Get KV cache statistics if available + if hasattr(llm, 'get_kv_cache_stats'): + cache_stats = llm.get_kv_cache_stats() + print(f"[KV_CACHE] Cache statistics: {cache_stats}") + result = "".join(collected) + # Find the last user message in provided messages and attempt to strip echoes + last_user = None + try: + for msg in reversed(messages): + if msg.get("role") == "user": + last_user = msg.get("content", "") + break + except Exception: + last_user = None + return _strip_echo_from_response(result, last_user) + + +# ---------------- GUI (Embedded only) ----------------- +class EmbeddedGUI: + def __init__(self, root: tk.Tk): + print("[APP_DEBUG] EmbeddedGUI.__init__() started") + self.root = root + self.root.title("DarkHal 2.0 - AI Model Management Platform") + + # Set window icon + try: + icon_path = os.path.join(os.path.dirname(__file__), "assets", "Halico.ico") + if os.path.exists(icon_path): + self.root.iconbitmap(icon_path) + except Exception: + pass + + # Set minimum window size + self.root.minsize(1000, 700) + + # Initialize settings manager + self.settings_manager = SettingsManager() + + # Initialize grouped download manager + max_concurrent = self.settings_manager.get('download_settings.max_concurrent_downloads', 3) + self.download_manager = GroupedDownloadManager(max_concurrent=max_concurrent) + + # Initialize agent mode attributes + self.agent_enabled = False + self.dhal_agent = None + + # Create menu bar + self._create_menu_bar() + + # Load settings and initialize variables + self.model_var = tk.StringVar(value=self.settings_manager.get('paths.last_model_path', '')) + self.stream_var = tk.BooleanVar(value=self.settings_manager.get('model_settings.stream_by_default', True)) + self.n_ctx_var = tk.IntVar(value=self.settings_manager.get('model_settings.default_n_ctx', 4096)) + # Set default GPU layers - use higher default if GPU is available and auto-config is enabled + default_gpu_layers = self.settings_manager.get('model_settings.default_n_gpu_layers', 0) + if default_gpu_layers == 0 and self.settings_manager.get('model_settings.auto_gpu', True): + # If auto-GPU is enabled and no custom default is set, use a reasonable default for GPU systems + try: + import torch + if torch.cuda.is_available(): + default_gpu_layers = 16 # Reasonable default for most 7B models + except: + pass + self.n_gpu_layers_var = tk.IntVar(value=default_gpu_layers) + self.lora_var = tk.StringVar(value=self.settings_manager.get('paths.last_lora_path', '')) + self.model_status_var = tk.StringVar(value="[not loaded]") + self.max_tokens_var = tk.IntVar( + value=self.settings_manager.get('model_settings.default_max_tokens', DEFAULT_MAX_TOKENS)) + self.chess_mode_var = tk.BooleanVar(value=self.settings_manager.get('model_settings.chess_mode', False)) + self.agent_mode_var = tk.BooleanVar(value=False) # Agent mode always starts disabled for safety + + # Advanced loading options + self.quantization_var = tk.StringVar(value=self.settings_manager.get('model_settings.quantization', 'none')) + self.device_strategy_var = tk.StringVar(value=self.settings_manager.get('model_settings.device_strategy', 'auto')) + self.chat_template_var = tk.StringVar(value="None") + self.gpu_memory_limit_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.gpu_memory_limit', 6.0)) + + # Sampling parameters + self.temperature_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.temperature', 0.7)) + self.top_p_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.top_p', 0.9)) + self.repetition_penalty_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.repetition_penalty', 1.1)) + self.no_repeat_ngram_size_var = tk.IntVar(value=self.settings_manager.get('model_settings.no_repeat_ngram_size', 0)) + self.min_p_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.min_p', 0.0)) + self.typical_p_var = tk.DoubleVar(value=self.settings_manager.get('model_settings.typical_p', 1.0)) + + # UI queue for thread-safe widget updates from worker threads + self._ui_queue: "queue.SimpleQueue[callable]" = queue.SimpleQueue() + self.root.after(30, self._drain_ui_queue) + + # Reset loaded status when key settings change + try: + self.model_var.trace_add("write", lambda *a: self._mark_model_unloaded()) + self.n_ctx_var.trace_add("write", lambda *a: self._mark_model_unloaded()) + self.n_gpu_layers_var.trace_add("write", lambda *a: self._mark_model_unloaded()) + self.lora_var.trace_add("write", lambda *a: self._mark_model_unloaded()) + self.chess_mode_var.trace_add("write", lambda *a: self._mark_model_unloaded()) + + # Advanced loading options - also mark model as unloaded and save settings + self.quantization_var.trace_add("write", lambda *a: self._on_advanced_setting_changed()) + self.device_strategy_var.trace_add("write", lambda *a: self._on_advanced_setting_changed()) + self.gpu_memory_limit_var.trace_add("write", lambda *a: self._on_advanced_setting_changed()) + + # Sampling parameters - save settings when changed + self.temperature_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + self.top_p_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + self.repetition_penalty_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + self.no_repeat_ngram_size_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + self.min_p_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + self.typical_p_var.trace_add("write", lambda *a: self._on_sampling_setting_changed()) + except Exception: + try: + self.model_var.trace("w", lambda *a: self._mark_model_unloaded()) + self.n_ctx_var.trace("w", lambda *a: self._mark_model_unloaded()) + self.n_gpu_layers_var.trace("w", lambda *a: self._mark_model_unloaded()) + self.lora_var.trace("w", lambda *a: self._mark_model_unloaded()) + self.chess_mode_var.trace("w", lambda *a: self._mark_model_unloaded()) + except Exception: + pass + + # Local models support with settings + self.models_dir_var = tk.StringVar(value=self.settings_manager.get('paths.models_directory', './models')) + self.local_model_var = tk.StringVar() + self._local_model_paths: Dict[str, str] = {} + + nb = ttk.Notebook(root) + nb.pack(fill=tk.BOTH, expand=True) + + # Single Run tab + self.run_frame = ttk.Frame(nb) + nb.add(self.run_frame, text="Run") + self._build_run_tab(self.run_frame) + + + # Model Library tab + self.library_frame = ttk.Frame(nb) + nb.add(self.library_frame, text="Model Library") + self.library_tab = ModelLibraryTab(self.library_frame, self.settings_manager) + + # Model Converter tab + self.converter_frame = ttk.Frame(nb) + nb.add(self.converter_frame, text="Model Converter") + self.converter_tab = ModelConverterTab(self.converter_frame, self.settings_manager) + + # Chess tab + self.chess_frame = ttk.Frame(nb) + nb.add(self.chess_frame, text="Chess") + self.chess_tab = ChessTab(self.chess_frame, self.settings_manager) + + + # Fine Tune tab - temporarily disabled for debugging + # self.finetune_frame = ttk.Frame(nb) + # nb.add(self.finetune_frame, text="Fine Tune") + # self.finetune_tab = FineTuneTab(self.finetune_frame, self.settings_manager) + + + + + # Initialize local models list if a folder is preset + if self.models_dir_var.get(): + self._refresh_local_models() + + self.chat_history: List[Dict[str, Any]] = [] + self._current_cancel: Optional[threading.Event] = None + # Initialize unique session ID for KV cache isolation + import uuid + self._session_id = f"chat_session_{uuid.uuid4().hex[:8]}" + + # Initialize chat template manager + self.template_manager = get_template_manager() + self._refresh_chat_templates() + + # Apply window size from settings + width = self.settings_manager.get('ui_preferences.window_width', 1200) + height = self.settings_manager.get('ui_preferences.window_height', 700) + self.root.geometry(f"{width}x{height}") + + # Save window size on close + self.root.protocol("WM_DELETE_WINDOW", self._on_closing) + + def _create_menu_bar(self): + """Create the application menu bar.""" + menubar = tk.Menu(self.root) + self.root.config(menu=menubar) + + # File menu + file_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="File", menu=file_menu) + file_menu.add_command(label="Open Model...", command=self._browse_gguf) + file_menu.add_separator() + file_menu.add_command(label="Settings...", command=self._open_settings) + file_menu.add_separator() + file_menu.add_command(label="Exit", command=self._on_closing) + + # Edit menu + edit_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Edit", menu=edit_menu) + edit_menu.add_command(label="Clear Output", command=lambda: self.output_text.delete('1.0', tk.END)) + edit_menu.add_command(label="Clear Chat History", command=self._clear_chat_history) + edit_menu.add_command(label="Clear KV Cache", command=self._clear_kv_cache) + + # Tools menu + tools_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Tools", menu=tools_menu) + tools_menu.add_command(label="HuggingFace Downloader", command=self._open_hf_downloader) + tools_menu.add_command(label="Downloads Manager", command=self._open_downloads_manager) + tools_menu.add_command(label="MCP Server", command=self._open_mcp_server) + tools_menu.add_command(label="Resource Monitor", command=self._open_resource_monitor) + + # Debug submenu + debug_menu = tk.Menu(tools_menu, tearoff=0) + tools_menu.add_separator() + tools_menu.add_cascade(label="Debug", menu=debug_menu) + debug_menu.add_command(label="Inspect Model Devices", command=self._inspect_model_devices) + tools_menu.add_command(label="Refresh Local Models", command=self._refresh_local_models) + tools_menu.add_separator() + tools_menu.add_command(label="Clear Completed Downloads", command=self._clear_completed_downloads) + tools_menu.add_command(label="MCP Server Config", command=self._open_mcp_config) + + # Agents menu + agents_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Agents", menu=agents_menu) + + # DarkHal submenu + darkhal_menu = tk.Menu(agents_menu, tearoff=0) + agents_menu.add_cascade(label="DarkHal", menu=darkhal_menu) + darkhal_menu.add_command(label="Dhal", command=self._open_dhal_agent) + darkhal_menu.add_command(label="Agent Dev Kit (ADK)", command=self._open_adk) + + # Metasploit option + agents_menu.add_command(label="Metasploit", command=self._open_metasploit) + + # Help menu + help_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Help", menu=help_menu) + help_menu.add_command(label="About", command=self._show_about) + + def _open_settings(self): + """Open the settings dialog.""" + open_settings_dialog(self.root, self.settings_manager) + # Reload settings after dialog closes + self._reload_settings() + + def _reload_settings(self): + """Reload settings after they've been changed.""" + # Update variables from settings + self.n_ctx_var.set(self.settings_manager.get('model_settings.default_n_ctx', 4096)) + self.n_gpu_layers_var.set(self.settings_manager.get('model_settings.default_n_gpu_layers', 0)) + self.max_tokens_var.set(self.settings_manager.get('model_settings.default_max_tokens', DEFAULT_MAX_TOKENS)) + self.stream_var.set(self.settings_manager.get('model_settings.stream_by_default', True)) + + # Reload HF API if token settings changed + if hasattr(self, 'hf_api'): + try: + from hf_downloader import HuggingFaceAPI + api_key = None + organization = None + + if not self.settings_manager.get('api.use_env_token', True): + api_key = self.settings_manager.get('api.huggingface_token', '').strip() + + if self.settings_manager.get('api.use_organization', False): + organization = self.settings_manager.get('api.organization', '').strip() + + self.hf_api = HuggingFaceAPI(api_key=api_key, organization=organization) + except Exception: + pass + + def _clear_chat_history(self): + """Clear the chat history and invalidate KV cache session.""" + self.chat_history.clear() + self.output_text.delete('1.0', tk.END) + + # Clear KV cache for the current session to prevent contamination + try: + if hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache: + llm = chat_stream._unified_model_cache + if hasattr(llm, 'clear_session_cache'): + llm.clear_session_cache(self._session_id) + print(f"[KV_CACHE] Cleared cache for session {self._session_id}") + except Exception as e: + print(f"[KV_CACHE] Error clearing session cache: {e}") + + # Generate new session ID for fresh conversation + import uuid + self._session_id = f"chat_session_{uuid.uuid4().hex[:8]}" + + self._append_output("[Chat history cleared - KV cache reset]\n") + + def _refresh_chat_templates(self): + """Refresh the chat template dropdown with available templates""" + try: + template_names = ["None"] + self.template_manager.get_template_names() + self.chat_template_combo['values'] = template_names + + # Set to "None" if current selection is not available + current = self.chat_template_var.get() + if current not in template_names: + self.chat_template_var.set("None") + except Exception as e: + print(f"Error refreshing chat templates: {e}") + + def _load_chat_template(self): + """Load chat templates from file""" + try: + filename = filedialog.askopenfilename( + title="Load Chat Templates", + filetypes=[("JSON files", "*.json"), ("All files", "*.*")], + parent=self.root + ) + + if filename: + # Load templates from the selected file + import json + with open(filename, 'r', encoding='utf-8') as f: + data = json.load(f) + + loaded_count = 0 + for name, template_data in data.items(): + try: + from chat_templates import ChatTemplate + template = ChatTemplate(**template_data) + if self.template_manager.add_template(template): + loaded_count += 1 + else: + # Template exists, ask if user wants to update + if messagebox.askyesno("Template Exists", + f"Template '{name}' already exists. Update it?"): + self.template_manager.update_template(template) + loaded_count += 1 + except Exception as e: + print(f"Error loading template '{name}': {e}") + + self._refresh_chat_templates() + messagebox.showinfo("Templates Loaded", f"Successfully loaded {loaded_count} template(s)") + + except Exception as e: + messagebox.showerror("Error", f"Error loading templates: {e}") + + def _add_chat_template(self): + """Add a new chat template""" + try: + dialog = ChatTemplateDialog(self.root) + self.root.wait_window(dialog.dialog) + + if dialog.result: + template = dialog.result + if self.template_manager.add_template(template): + self._refresh_chat_templates() + self.chat_template_var.set(template.name) + messagebox.showinfo("Success", f"Template '{template.name}' added successfully") + else: + # Template exists, ask if user wants to update + if messagebox.askyesno("Template Exists", + f"Template '{template.name}' already exists. Update it?"): + self.template_manager.update_template(template) + self._refresh_chat_templates() + messagebox.showinfo("Success", f"Template '{template.name}' updated successfully") + + except Exception as e: + messagebox.showerror("Error", f"Error adding template: {e}") + + def _clear_kv_cache(self): + """Clear the KV cache for the current chat session.""" + try: + # Clear cache for unified models (HuggingFace) + if hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache: + llm = chat_stream._unified_model_cache + if hasattr(llm, 'clear_session_cache'): + llm.clear_session_cache(self._session_id) + self._append_output(f"[KV Cache cleared for session {self._session_id}]\n") + + # Show cache statistics after clearing + if hasattr(llm, 'get_kv_cache_stats'): + cache_stats = llm.get_kv_cache_stats() + self._append_output(f"[Cache Stats] Active sessions: {cache_stats.get('active_sessions', 0)}\n") + else: + self._append_output("[KV Cache clear not supported for this model]\n") + else: + self._append_output("[No cached model found - KV cache already clear]\n") + except Exception as e: + self._append_output(f"[Error clearing KV cache: {e}]\n") + + def _open_hf_downloader(self): + """Open standalone HuggingFace downloader window.""" + try: + from hf_downloader import HuggingFaceDownloaderGUI + downloader_window = tk.Toplevel(self.root) + HuggingFaceDownloaderGUI(downloader_window) + except Exception as e: + messagebox.showerror("Error", f"Failed to open HuggingFace downloader: {e}") + + def _show_about(self): + """Show about dialog.""" + about_text = ( + "LLM_Train - Advanced Local Model Manager\n\n" + "A comprehensive local GGUF model runner with cloud integration\n\n" + "Features:\n" + "• Run local GGUF models with optimized performance\n" + "• Search and download from HuggingFace Hub\n" + "• Advanced download manager with pause/resume/retry\n" + "• Model Library with smart scanning and indexing\n" + "• Multi-model MCP server for Claude integration\n" + "• Organization support for HuggingFace teams\n" + "• Chat and single prompt modes\n" + "• Customizable settings and preferences\n" + "• Optimized USB/SSD write speeds\n\n" + "Powered by llama-cpp-python and MCP protocol" + ) + messagebox.showinfo("About", about_text) + + def _clear_completed_downloads(self): + """Clear completed downloads from download manager.""" + if hasattr(self, 'download_tab'): + self.download_tab._clear_completed() + + def _open_mcp_config(self): + """Open MCP server configuration.""" + try: + open_mcp_config(self.root) + except Exception as e: + messagebox.showerror("Error", f"Failed to open MCP configuration: {e}") + + def _open_resource_monitor(self): + """Open Resource Monitor in a new window.""" + resource_window = tk.Toplevel(self.root) + resource_window.title("Resource Monitor") + resource_window.geometry("800x600") + resource_window.transient(self.root) + + # Build resource monitor content in the new window + self._build_resource_tab(resource_window) + + def _open_downloads_manager(self): + """Open Downloads Manager in a new window.""" + downloads_window = tk.Toplevel(self.root) + downloads_window.title("Downloads Manager") + downloads_window.geometry("900x700") + downloads_window.transient(self.root) + + # Create download manager in the new window + from download_manager_tab import GroupedDownloadManagerTab + GroupedDownloadManagerTab(downloads_window, self.download_manager) + + def _open_mcp_server(self): + """Open MCP Server in a new window.""" + mcp_window = tk.Toplevel(self.root) + mcp_window.title("MCP Server") + mcp_window.geometry("800x600") + mcp_window.transient(self.root) + + # Create MCP server tab in the new window + MCPTab(mcp_window, self.settings_manager) + + def _open_dhal_agent(self): + """Open Dhal Dark Agent in a new window.""" + dhal_window = tk.Toplevel(self.root) + dhal_window.title("Dhal - Dark Agent") + dhal_window.geometry("1000x700") + dhal_window.transient(self.root) + + # Create Dark Agent tab in the new window with proper main_app reference + from dark_agent import DarkAgentTab + dark_agent_tab = DarkAgentTab(dhal_window, self.settings_manager, self) + # Ensure the main_app reference is properly set + dark_agent_tab.main_app = self + + def _open_adk(self): + """Open Agent Development Kit in a new window.""" + adk_window = tk.Toplevel(self.root) + adk_window.title("Agent Development Kit (ADK)") + adk_window.geometry("900x600") + adk_window.transient(self.root) + + # Create ADK interface + adk_frame = ttk.Frame(adk_window) + adk_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=20) + + ttk.Label(adk_frame, text="Agent Development Kit", font=("Arial", 16, "bold")).pack(pady=10) + ttk.Label(adk_frame, text="Advanced tools for creating and managing AI agents").pack(pady=5) + ttk.Label(adk_frame, text="Coming Soon...", font=("Arial", 12, "italic")).pack(pady=20) + + def _open_metasploit(self): + """Open Metasploit interface in a new window.""" + metasploit_window = tk.Toplevel(self.root) + metasploit_window.title("Metasploit") + metasploit_window.geometry("1000x700") + metasploit_window.transient(self.root) + + # Create Metasploit interface + from pentestgpt import PentestGPTTab + PentestGPTTab(metasploit_window, self.settings_manager, self) + + def _inspect_model_devices(self): + """Open device inspection dialog""" + from tools.inspect_devices import inspect_loaded_model, inspect_model_devices + + # Check if we have a loaded model + current_model_path = self.model_var.get() + if not current_model_path or current_model_path == "Select a model...": + tk.messagebox.showwarning("No Model", "Please load a model first.") + return + + # Create inspection window + inspect_window = tk.Toplevel(self.root) + inspect_window.title("Model Device Inspection") + inspect_window.geometry("800x600") + inspect_window.configure(bg='#2b2b2b') + + # Create text widget with scrollbar + frame = tk.Frame(inspect_window, bg='#2b2b2b') + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + text_widget = tk.Text(frame, bg='#1e1e1e', fg='#ffffff', font=('Consolas', 10), wrap=tk.WORD) + scrollbar = tk.Scrollbar(frame, orient=tk.VERTICAL, command=text_widget.yview) + text_widget.configure(yscrollcommand=scrollbar.set) + + text_widget.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # Add inspection results + text_widget.insert(tk.END, "Inspecting model devices...\n\n") + text_widget.update() + + try: + # Check if we have a loaded unified model + if hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache: + result = inspect_loaded_model(chat_stream._unified_model_cache.model) + else: + # Fallback: load and inspect the model path + result = inspect_model_devices(current_model_path) + + text_widget.delete('1.0', tk.END) + text_widget.insert(tk.END, result) + except Exception as e: + text_widget.delete('1.0', tk.END) + text_widget.insert(tk.END, f"Error inspecting model:\n{str(e)}") + + text_widget.config(state=tk.DISABLED) + + def _on_closing(self): + """Handle window closing event.""" + # Save current window size + self.settings_manager.set('ui_preferences.window_width', self.root.winfo_width()) + self.settings_manager.set('ui_preferences.window_height', self.root.winfo_height()) + + # Save other current values + if self.model_var.get(): + self.settings_manager.set('paths.last_model_path', self.model_var.get()) + if self.lora_var.get(): + self.settings_manager.set('paths.last_lora_path', self.lora_var.get()) + + self.settings_manager.save_settings() + self.root.destroy() + + def _build_run_tab(self, frame: ttk.Frame): + # Create notebook for Run sub-tabs + run_notebook = ttk.Notebook(frame) + run_notebook.pack(fill=tk.BOTH, expand=True) + + # Chat sub-tab + self.chat_frame = ttk.Frame(run_notebook) + run_notebook.add(self.chat_frame, text="Chat") + self._build_chat_subtab(self.chat_frame) + + # Model Settings sub-tab + self.model_settings_frame = ttk.Frame(run_notebook) + run_notebook.add(self.model_settings_frame, text="Model Settings") + self._build_model_settings_tab(self.model_settings_frame) + + def _build_chat_subtab(self, frame: ttk.Frame): + # Model loading section + model_frame = ttk.LabelFrame(frame, text="Model Selection", padding="10") + model_frame.pack(fill=tk.X, padx=8, pady=8) + + ttk.Label(model_frame, text="Model:").grid(row=0, column=0, sticky=tk.W) + self.model_entry = ttk.Entry(model_frame, textvariable=self.model_var, width=50) + self.model_entry.grid(row=0, column=1, sticky=tk.EW, padx=5) + self.browse_model_btn = ttk.Button(model_frame, text="Browse Model", command=self._browse_gguf) + self.browse_model_btn.grid(row=0, column=2, padx=2) + self.browse_folder_btn = ttk.Button(model_frame, text="Browse Folder", command=self._browse_folder) + self.browse_folder_btn.grid(row=0, column=3, padx=2) + self.load_model_btn = ttk.Button(model_frame, text="Load Model", command=self._on_load_unload_model) + self.load_model_btn.grid(row=0, column=4, padx=5) + + # Chat Template row + ttk.Label(model_frame, text="Chat Template:").grid(row=1, column=0, sticky=tk.W, pady=(5, 0)) + self.chat_template_combo = ttk.Combobox(model_frame, textvariable=self.chat_template_var, + values=["None"], state="readonly", width=20) + self.chat_template_combo.grid(row=1, column=1, sticky=tk.W, padx=5, pady=(5, 0)) + ttk.Button(model_frame, text="Load", command=self._load_chat_template).grid(row=1, column=2, padx=2, pady=(5, 0)) + ttk.Button(model_frame, text="Add", command=self._add_chat_template).grid(row=1, column=3, padx=2, pady=(5, 0)) + + status_frame = ttk.Frame(model_frame) + status_frame.grid(row=2, column=0, columnspan=5, sticky=tk.W, pady=(5, 0)) + ttk.Label(status_frame, textvariable=self.model_status_var).pack(side=tk.LEFT) + ttk.Label(status_frame, text=" | Supports: GGUF, SafeTensors, GPTQ, AWQ, EXL2, PyTorch", + font=('Arial', 8), foreground='gray').pack(side=tk.LEFT, padx=(10, 0)) + + # Configure grid weights for resizing + model_frame.grid_columnconfigure(1, weight=1) + + # Options section + options_frame = ttk.Frame(frame) + options_frame.pack(fill=tk.X, padx=8, pady=(0, 8)) + + ttk.Checkbutton(options_frame, text="Chess Mode (ChessGPT)", variable=self.chess_mode_var, + command=self._on_chess_mode_changed).pack(side=tk.LEFT) + ttk.Checkbutton(options_frame, text="Stream Output", variable=self.stream_var).pack(side=tk.LEFT, padx=(20, 0)) + + # Agent Mode controls + self.agent_mode_var = tk.BooleanVar(value=False) + agent_btn = ttk.Checkbutton(options_frame, text="🤖 Agent Mode (SYSTEM ACCESS)", + variable=self.agent_mode_var, + command=self._on_agent_mode_changed) + agent_btn.pack(side=tk.LEFT, padx=(20, 0)) + + # Initialize agent handler + self.agent_handler = None + self._init_agent_mode() + + mid = ttk.Frame(frame) + mid.pack(fill=tk.BOTH, expand=True, padx=8, pady=4) + ttk.Label(mid, text="Prompt / Chat Input:").pack(anchor=tk.W) + self.prompt_text = tk.Text(mid, height=6) + self.prompt_text.pack(fill=tk.BOTH, expand=True) + + btns = ttk.Frame(frame) + btns.pack(fill=tk.X, padx=8, pady=4) + self.send_btn = ttk.Button(btns, text="Send (Chat)", command=self._on_chat) + self.send_btn.pack(side=tk.LEFT) + self.stop_btn = ttk.Button(btns, text="Stop", command=self._on_stop, state="disabled") + self.stop_btn.pack(side=tk.LEFT, padx=6) + ttk.Button(btns, text="Clear Output", command=lambda: self.output_text.delete('1.0', tk.END)).pack(side=tk.LEFT, + padx=6) + + out = ttk.Frame(frame) + out.pack(fill=tk.BOTH, expand=True, padx=8, pady=4) + ttk.Label(out, text="Output:").pack(anchor=tk.W) + self.output_text = tk.Text(out, height=12) + self.output_text.pack(fill=tk.BOTH, expand=True) + + def _add_tooltip(self, widget, text): + """Add a tooltip to a widget""" + def create_tooltip(widget, text): + def on_enter(event): + # Prevent multiple tooltips + if hasattr(widget, 'tooltip') and widget.tooltip: + return + + try: + tooltip = tk.Toplevel() + tooltip.wm_overrideredirect(True) + x = widget.winfo_rootx() + 20 + y = widget.winfo_rooty() + 20 + tooltip.wm_geometry(f"+{x}+{y}") + label = tk.Label(tooltip, text=text, background="lightyellow", + relief="solid", borderwidth=1, font=("Arial", "9", "normal")) + label.pack() + widget.tooltip = tooltip + except: + # Ignore tooltip creation errors + pass + + def on_leave(event): + try: + if hasattr(widget, 'tooltip') and widget.tooltip: + widget.tooltip.destroy() + widget.tooltip = None + except: + # Ignore tooltip destruction errors + pass + + widget.bind("", on_enter) + widget.bind("", on_leave) + + create_tooltip(widget, text) + + def _build_model_settings_tab(self, frame: ttk.Frame): + # Context and GPU settings + ctx_frame = ttk.LabelFrame(frame, text="Context & GPU Settings", padding="10") + ctx_frame.pack(fill=tk.X, padx=8, pady=8) + + # Auto-config checkbox + self.auto_context_var = tk.BooleanVar(value=self.settings_manager.get('model_settings.auto_context', True)) + self.auto_context_check = ttk.Checkbutton( + ctx_frame, + text="Auto-configure context size based on model", + variable=self.auto_context_var, + command=self._on_auto_context_changed + ) + self.auto_context_check.grid(row=0, column=0, columnspan=3, sticky=tk.W, pady=(0, 10)) + self._add_tooltip(self.auto_context_check, "Automatically use the model's trained context size (n_ctx_train) for optimal performance.\nDisable to manually set context size.") + + ttk.Label(ctx_frame, text="Context Length (n_ctx):").grid(row=1, column=0, sticky=tk.W) + self.n_ctx_spin = ttk.Entry(ctx_frame, textvariable=self.n_ctx_var, width=15) + self.n_ctx_spin.grid(row=1, column=1, sticky=tk.W, padx=5) + self._add_tooltip(self.n_ctx_spin, "Maximum number of tokens the model can process at once.\nHigher values use more memory but allow longer conversations.") + ttk.Label(ctx_frame, text="tokens").grid(row=1, column=2, sticky=tk.W) + + # Disable manual entry if auto-config is enabled + if self.auto_context_var.get(): + self.n_ctx_spin.configure(state='disabled') + + # Auto-GPU config checkbox + self.auto_gpu_var = tk.BooleanVar(value=self.settings_manager.get('model_settings.auto_gpu', True)) + self.auto_gpu_check = ttk.Checkbutton( + ctx_frame, + text="Auto-configure GPU layers for optimal performance", + variable=self.auto_gpu_var, + command=self._on_auto_gpu_changed + ) + self.auto_gpu_check.grid(row=2, column=0, columnspan=3, sticky=tk.W, pady=(10, 5)) + self._add_tooltip(self.auto_gpu_check, "Automatically set GPU layers based on your VRAM and model size.\nDisable to manually set GPU layers.") + + ttk.Label(ctx_frame, text="GPU Layers (n_gpu_layers):").grid(row=3, column=0, sticky=tk.W, pady=(5, 0)) + self.n_gpu_spin = ttk.Entry(ctx_frame, textvariable=self.n_gpu_layers_var, width=15) + self.n_gpu_spin.grid(row=3, column=1, sticky=tk.W, padx=5, pady=(5, 0)) + self._add_tooltip(self.n_gpu_spin, "Number of model layers to offload to GPU.\nHigher values improve speed but use more VRAM.\nUse 0 for CPU-only.") + ttk.Label(ctx_frame, text="layers").grid(row=3, column=2, sticky=tk.W, pady=(5, 0)) + + # Disable manual entry if auto-config is enabled + if self.auto_gpu_var.get(): + self.n_gpu_spin.configure(state='disabled') + + # LoRA settings + lora_frame = ttk.LabelFrame(frame, text="LoRA Adapter Settings", padding="10") + lora_frame.pack(fill=tk.X, padx=8, pady=8) + + ttk.Label(lora_frame, text="LoRA Path (optional):").grid(row=0, column=0, sticky=tk.W) + self.lora_entry = ttk.Entry(lora_frame, textvariable=self.lora_var, width=60) + self.lora_entry.grid(row=0, column=1, sticky=tk.EW, padx=5) + self._add_tooltip(self.lora_entry, "Path to LoRA (Low-Rank Adaptation) adapter file.\nLoRA adapters fine-tune model behavior without changing base weights.\nLeave empty if not using LoRA.") + self.lora_btn = ttk.Button(lora_frame, text="Browse", command=self._browse_lora) + self.lora_btn.grid(row=0, column=2, padx=5) + lora_frame.grid_columnconfigure(1, weight=1) + + # Generation settings + gen_frame = ttk.LabelFrame(frame, text="Generation Settings", padding="10") + gen_frame.pack(fill=tk.X, padx=8, pady=8) + + # Max tokens (renamed for clarity) + ttk.Label(gen_frame, text="Max New Tokens (n_predict):").grid(row=0, column=0, sticky=tk.W) + self.max_tokens_spin = tk.Spinbox(gen_frame, from_=16, to=8192, increment=16, textvariable=self.max_tokens_var, + width=15) + self.max_tokens_spin.grid(row=0, column=1, sticky=tk.W, padx=5) + self._add_tooltip(self.max_tokens_spin, "Maximum number of new tokens to generate.\nHigher values allow longer responses but take more time.") + + # Temperature + ttk.Label(gen_frame, text="Temperature:").grid(row=1, column=0, sticky=tk.W, pady=(10, 0)) + temp_spin = tk.Spinbox(gen_frame, from_=0.0, to=2.0, increment=0.1, + textvariable=self.temperature_var, width=15, format="%.1f") + temp_spin.grid(row=1, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(temp_spin, "Controls randomness in generation.\n0.0 = deterministic, 1.0 = balanced, 2.0 = very creative.\nLower values for factual tasks, higher for creative tasks.") + + # Top P + ttk.Label(gen_frame, text="Top P:").grid(row=2, column=0, sticky=tk.W, pady=(10, 0)) + top_p_spin = tk.Spinbox(gen_frame, from_=0.0, to=1.0, increment=0.1, + textvariable=self.top_p_var, width=15, format="%.1f") + top_p_spin.grid(row=2, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(top_p_spin, "Nucleus sampling parameter.\nOnly consider tokens in the top P probability mass.\n0.9 is typical, lower values for more focused responses.") + + # Repetition Penalty + ttk.Label(gen_frame, text="Repetition Penalty:").grid(row=3, column=0, sticky=tk.W, pady=(10, 0)) + rep_pen_spin = tk.Spinbox(gen_frame, from_=0.5, to=2.0, increment=0.1, + textvariable=self.repetition_penalty_var, width=15, format="%.1f") + rep_pen_spin.grid(row=3, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(rep_pen_spin, "Penalty for repeating tokens.\n1.0 = no penalty, >1.0 = discourage repetition.\n1.1 is typical, higher values reduce repetition more.") + + # No Repeat N-gram Size + ttk.Label(gen_frame, text="No Repeat N-gram Size:").grid(row=4, column=0, sticky=tk.W, pady=(10, 0)) + ngram_spin = tk.Spinbox(gen_frame, from_=0, to=10, increment=1, + textvariable=self.no_repeat_ngram_size_var, width=15) + ngram_spin.grid(row=4, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(ngram_spin, "Prevent repeating N-grams (sequences of N tokens).\n0 = disabled, 2-4 = typical values.\nHigher values prevent more repetitive patterns.") + + # Min P + ttk.Label(gen_frame, text="Min P:").grid(row=5, column=0, sticky=tk.W, pady=(10, 0)) + min_p_spin = tk.Spinbox(gen_frame, from_=0.0, to=1.0, increment=0.01, + textvariable=self.min_p_var, width=15, format="%.2f") + min_p_spin.grid(row=5, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(min_p_spin, "Minimum probability threshold.\nTokens below this probability are excluded.\n0.0 = disabled, 0.05 = typical value.") + + # Typical P + ttk.Label(gen_frame, text="Typical P:").grid(row=6, column=0, sticky=tk.W, pady=(10, 0)) + typical_p_spin = tk.Spinbox(gen_frame, from_=0.0, to=1.0, increment=0.1, + textvariable=self.typical_p_var, width=15, format="%.1f") + typical_p_spin.grid(row=6, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(typical_p_spin, "Typical sampling parameter.\nFocuses on tokens with 'typical' information content.\n1.0 = disabled, 0.95 = typical value.") + + # Advanced Loading Options + advanced_frame = ttk.LabelFrame(frame, text="Advanced Loading Options", padding="10") + advanced_frame.pack(fill=tk.X, padx=8, pady=8) + + # Quantization options + ttk.Label(advanced_frame, text="Quantization:").grid(row=0, column=0, sticky=tk.W) + self.quantization_combo = ttk.Combobox(advanced_frame, textvariable=self.quantization_var, + values=["none", "4bit", "8bit", "gptq", "awq", "exl2"], state="readonly", width=20) + self.quantization_combo.grid(row=0, column=1, sticky=tk.W, padx=5) + self._add_tooltip(self.quantization_combo, "Reduce model memory usage by using lower precision.\nnone = full precision\n4bit/8bit = bitsandbytes quantization\ngptq/awq/exl2 = specialized quantization formats") + ttk.Label(advanced_frame, text="(auto-detected for pre-quantized models)").grid(row=0, column=2, sticky=tk.W, padx=(10, 0)) + + # Device strategy + ttk.Label(advanced_frame, text="Device Strategy:").grid(row=1, column=0, sticky=tk.W, pady=(10, 0)) + self.device_combo = ttk.Combobox(advanced_frame, textvariable=self.device_strategy_var, + values=["auto", "force_gpu", "balanced_split", "cpu_only"], + state="readonly", width=20) + self.device_combo.grid(row=1, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(self.device_combo, "How to distribute model across devices.\nauto = automatic distribution\nforce_gpu = all on GPU\nbalanced_split = split between CPU/GPU\ncpu_only = CPU only") + ttk.Label(advanced_frame, text="(balanced_split for large models)").grid(row=1, column=2, sticky=tk.W, padx=(10, 0), pady=(10, 0)) + + # GPU memory limit + ttk.Label(advanced_frame, text="GPU Memory Limit:").grid(row=2, column=0, sticky=tk.W, pady=(10, 0)) + self.gpu_mem_spin = tk.Spinbox(advanced_frame, from_=1.0, to=24.0, increment=0.5, + textvariable=self.gpu_memory_limit_var, width=15, format="%.1f") + self.gpu_mem_spin.grid(row=2, column=1, sticky=tk.W, padx=5, pady=(10, 0)) + self._add_tooltip(self.gpu_mem_spin, "Maximum GPU memory to use (in GB).\nUsed with balanced_split strategy.\nSet below your GPU's total VRAM to leave room for other applications.") + ttk.Label(advanced_frame, text="GB (for balanced_split)").grid(row=2, column=2, sticky=tk.W, padx=(10, 0), pady=(10, 0)) + + # Local models browser + local_frame = ttk.LabelFrame(frame, text="Local Models Browser", padding="10") + local_frame.pack(fill=tk.X, padx=8, pady=8) + + ttk.Label(local_frame, text="Local Models:").grid(row=0, column=0, sticky=tk.W) + self.local_models_combo = ttk.Combobox(local_frame, textvariable=self.local_model_var, width=50, + state="readonly") + self.local_models_combo.grid(row=0, column=1, sticky=tk.EW, padx=5) + self.local_models_combo.bind("<>", self._on_local_model_selected) + ttk.Button(local_frame, text="Folder...", command=self._choose_models_folder).grid(row=0, column=2, padx=5) + ttk.Button(local_frame, text="Refresh", command=self._refresh_local_models).grid(row=0, column=3, padx=5) + local_frame.grid_columnconfigure(1, weight=1) + + def _build_resource_tab(self, frame: ttk.Frame): + # GPU Information + gpu_frame = ttk.LabelFrame(frame, text="GPU Information", padding="10") + gpu_frame.pack(fill=tk.X, padx=8, pady=8) + + self.gpu_info_var = tk.StringVar(value="Checking GPU...") + self.gpu_memory_var = tk.StringVar(value="Memory: Unknown") + self.gpu_usage_var = tk.StringVar(value="Usage: Unknown") + + ttk.Label(gpu_frame, textvariable=self.gpu_info_var).grid(row=0, column=0, sticky=tk.W, columnspan=3) + ttk.Label(gpu_frame, textvariable=self.gpu_memory_var).grid(row=1, column=0, sticky=tk.W, pady=(5, 0)) + ttk.Label(gpu_frame, textvariable=self.gpu_usage_var).grid(row=2, column=0, sticky=tk.W, pady=(5, 0)) + ttk.Button(gpu_frame, text="Test GPU", command=self._test_gpu).grid(row=1, column=2, padx=5, rowspan=2) + + # CPU Information + cpu_frame = ttk.LabelFrame(frame, text="CPU Information", padding="10") + cpu_frame.pack(fill=tk.X, padx=8, pady=8) + + self.cpu_info_var = tk.StringVar(value="Detecting CPU...") + self.cpu_usage_var = tk.StringVar(value="Usage: Unknown") + self.ram_usage_var = tk.StringVar(value="RAM: Unknown") + + ttk.Label(cpu_frame, textvariable=self.cpu_info_var).grid(row=0, column=0, sticky=tk.W, columnspan=2) + ttk.Label(cpu_frame, textvariable=self.cpu_usage_var).grid(row=1, column=0, sticky=tk.W, pady=(5, 0)) + ttk.Label(cpu_frame, textvariable=self.ram_usage_var).grid(row=2, column=0, sticky=tk.W, pady=(5, 0)) + + # Resource monitoring controls + controls_frame = ttk.LabelFrame(frame, text="Monitoring Controls", padding="10") + controls_frame.pack(fill=tk.X, padx=8, pady=8) + + self.monitor_var = tk.BooleanVar(value=False) + ttk.Checkbutton(controls_frame, text="Enable Real-time Monitoring", variable=self.monitor_var, + command=self._toggle_monitoring).pack(side=tk.LEFT) + ttk.Button(controls_frame, text="Refresh Now", command=self._refresh_resources).pack(side=tk.LEFT, padx=(20, 0)) + + # Initialize resource monitoring + self._initialize_resource_monitoring() + + # ---------------- HuggingFace Browser Tab ----------------- + def _build_hf_tab(self, frame: ttk.Frame): + # Import the new HuggingFace downloader module + try: + from hf_downloader import HuggingFaceAPI + # Configure API based on settings + api_key = None + organization = None + + if not self.settings_manager.get('api.use_env_token', True): + api_key = self.settings_manager.get('api.huggingface_token', '').strip() + + if self.settings_manager.get('api.use_organization', False): + organization = self.settings_manager.get('api.organization', '').strip() + + self.hf_api = HuggingFaceAPI(api_key=api_key, organization=organization) + except ImportError: + ttk.Label(frame, + text="hf_downloader module not found. Please ensure hf_downloader.py is in the same directory.").pack( + padx=8, pady=8) + return + except ValueError as e: + ttk.Label(frame, text=f"API Key Error: {e}").pack(padx=8, pady=8) + return + except Exception as e: + ttk.Label(frame, text=f"Error initializing HuggingFace API: {e}").pack(padx=8, pady=8) + return + + # Search bar with dropdown (using settings defaults) + search_row = ttk.Frame(frame) + search_row.pack(fill=tk.X, padx=8, pady=8) + self.hf_search_query = tk.StringVar() + self.hf_search_type = tk.StringVar( + value=self.settings_manager.get('search_preferences.default_search_type', 'Models')) + + # Search entry + self.hf_search_entry = ttk.Entry(search_row, textvariable=self.hf_search_query, width=60) + self.hf_search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + self.hf_search_entry.bind("", lambda e: self._hf_search()) + + # Search type dropdown + self.hf_type_combo = ttk.Combobox(search_row, values=["Models", "Datasets"], + textvariable=self.hf_search_type, state="readonly", width=15) + self.hf_type_combo.pack(side=tk.LEFT, padx=(10, 0)) + + # Search button + ttk.Button(search_row, text="Search", command=self._hf_search).pack(side=tk.LEFT, padx=(10, 0)) + + # Results area with enhanced columns + results_frame = ttk.Frame(frame) + results_frame.pack(fill=tk.BOTH, expand=True, padx=8, pady=(0, 8)) + + # Create treeview with new column structure + cols = ("creator", "name", "description", "keywords", "size", "metadata") + self.hf_tree = ttk.Treeview(results_frame, columns=cols, show="headings", height=15) + + # Define column headings and widths + self.hf_tree.heading("creator", text="Creator") + self.hf_tree.heading("name", text="Name") + self.hf_tree.heading("description", text="Description") + self.hf_tree.heading("keywords", text="Keywords") + self.hf_tree.heading("size", text="Size") + self.hf_tree.heading("metadata", text="Metadata") + + self.hf_tree.column("creator", width=120) + self.hf_tree.column("name", width=200) + self.hf_tree.column("description", width=250) + self.hf_tree.column("keywords", width=150) + self.hf_tree.column("size", width=80) + self.hf_tree.column("metadata", width=150) + + self.hf_tree.pack(fill=tk.BOTH, expand=True, side=tk.LEFT) + self.hf_tree.bind("", self._hf_download_selected) + + # Scrollbars + vsb = ttk.Scrollbar(results_frame, orient="vertical", command=self.hf_tree.yview) + hsb = ttk.Scrollbar(results_frame, orient="horizontal", command=self.hf_tree.xview) + self.hf_tree.configure(yscrollcommand=vsb.set, xscrollcommand=hsb.set) + vsb.pack(side=tk.RIGHT, fill=tk.Y) + + # Filter footer with checkboxes + filter_frame = ttk.Frame(frame) + filter_frame.pack(fill=tk.X, padx=8, pady=(0, 8)) + + ttk.Label(filter_frame, text="Filter:").pack(side=tk.LEFT) + + # Initialize filters based on default sort preference + default_sort = self.settings_manager.get('search_preferences.default_sort', 'downloads') + self.filter_most_downloaded = tk.BooleanVar(value=(default_sort == 'downloads')) + self.filter_most_liked = tk.BooleanVar(value=(default_sort == 'likes')) + self.filter_size = tk.BooleanVar(value=(default_sort == 'lastModified')) + + ttk.Checkbutton(filter_frame, text="Most Downloaded", + variable=self.filter_most_downloaded).pack(side=tk.LEFT, padx=(10, 0)) + ttk.Checkbutton(filter_frame, text="Most Liked", + variable=self.filter_most_liked).pack(side=tk.LEFT, padx=(10, 0)) + ttk.Checkbutton(filter_frame, text="Size", + variable=self.filter_size).pack(side=tk.LEFT, padx=(10, 0)) + + # Download button on the right + ttk.Button(filter_frame, text="Download Selected", + command=self._hf_download_selected).pack(side=tk.RIGHT, padx=5) + + # Status area for HF tab + self.hf_status_var = tk.StringVar(value="Ready") + ttk.Label(frame, textvariable=self.hf_status_var).pack(fill=tk.X, padx=8, pady=(0, 8)) + + # Holder for last results (to act on selection) + self._hf_results: List[Dict[str, Any]] = [] + + def _hf_set_status(self, text: str): + try: + self.hf_status_var.set(text) + except Exception: + pass + + def _format_bytes(self, n: Optional[int]) -> str: + try: + if not n or n <= 0: + return "-" + units = ["B", "KB", "MB", "GB", "TB"] + i = 0 + f = float(n) + while f >= 1024 and i < len(units) - 1: + f /= 1024.0 + i += 1 + return f"{f:.1f} {units[i]}" + except Exception: + return "-" + + def _format_number(self, num: int) -> str: + """Format large numbers with K, M suffixes.""" + if num >= 1_000_000: + return f"{num / 1_000_000:.1f}M" + elif num >= 1_000: + return f"{num / 1_000:.1f}K" + return str(num) + + def _hf_search(self): + query = (self.hf_search_query.get() or "").strip() + search_type = self.hf_search_type.get() + + # Determine sort parameter based on filters + sort = "downloads" + if self.filter_most_liked.get() and not self.filter_most_downloaded.get(): + sort = "likes" + elif self.filter_size.get() and not self.filter_most_downloaded.get() and not self.filter_most_liked.get(): + sort = "lastModified" + + self._hf_set_status("Searching ...") + self.hf_tree.delete(*self.hf_tree.get_children()) + self._hf_results = [] + + threading.Thread(target=self._hf_perform_search_thread, args=(search_type, query, sort), daemon=True).start() + + def _hf_perform_search_thread(self, search_type: str, query: str, sort: str): + try: + # Use the new API + if search_type == "Models": + results = self.hf_api.search_models(query, limit=50, sort=sort) + else: + results = self.hf_api.search_datasets(query, limit=50, sort=sort) + + rows = [] + packed = [] + + for item in results: + try: + # Extract fields based on search type + if search_type == "Models": + repo_id = item.get("modelId", item.get("id", "")) + pipeline_tag = item.get("pipeline_tag", "") + tags = item.get("tags", []) + keywords = ", ".join(tags[:3]) if tags else pipeline_tag + description = item.get("description", "") + else: + repo_id = item.get("id", "") + task_ids = item.get("cardData", {}).get("task_ids", []) + keywords = ", ".join(task_ids[:3]) if task_ids else "dataset" + card_data = item.get("cardData", {}) + description = card_data.get("description", card_data.get("summary", "")) + + creator = repo_id.split("/")[0] if "/" in repo_id else "" + name = repo_id.split("/")[1] if "/" in repo_id else repo_id + + # Truncate description + if len(description) > 100: + description = description[:97] + "..." + + # Calculate size + size_bytes = 0 + siblings = item.get("siblings", []) + for sibling in siblings: + if isinstance(sibling, dict): + size = sibling.get("size", 0) + if isinstance(size, (int, float)): + size_bytes += size + + size_str = self._format_bytes(size_bytes) if size_bytes > 0 else "-" + + # Get metadata + metadata_parts = [] + downloads = item.get("downloads", 0) + likes = item.get("likes", 0) + + if downloads > 0: + metadata_parts.append(f"↓{self._format_number(downloads)}") + if likes > 0: + metadata_parts.append(f"♥{self._format_number(likes)}") + + if search_type == "Models": + library = item.get("library_name", "") + if library: + metadata_parts.append(library) + + metadata = " | ".join(metadata_parts) + + # Prepare row + rows.append((creator, name, description, keywords, size_str, metadata)) + packed.append({ + "type": search_type, + "repo_id": repo_id, + }) + except Exception: + continue + + def apply_rows(): + try: + for row in rows: + self.hf_tree.insert("", tk.END, values=row) + self._hf_results = packed + self._hf_set_status(f"Found {len(rows)} {search_type.lower()}") + except Exception: + pass + + self._enqueue_ui(apply_rows) + except Exception as e: + self._enqueue_ui(lambda: self._hf_set_status(f"Search error: {e}")) + + def _hf_download_selected(self, event=None): + """Download the selected model or dataset using grouped download manager.""" + selection = self.hf_tree.selection() + if not selection: + messagebox.showinfo("No Selection", "Please select an item to download") + return + + item = self.hf_tree.item(selection[0]) + values = item['values'] + + if len(values) < 2: + return + + creator = values[0] + name = values[1] + repo_id = f"{creator}/{name}" if creator else name + + # Ask for download location (use default from settings) + initial_dir = self.settings_manager.get('paths.downloads_directory', './downloads') + download_dir = filedialog.askdirectory( + title="Select Download Directory", + initialdir=initial_dir + ) + if not download_dir: + return + + self._hf_set_status(f"Fetching file list for {repo_id}...") + + def prepare_downloads(): + try: + # Get file list + files = self.hf_api.get_model_files(repo_id) + + if not files: + self._enqueue_ui(lambda: self._append_output(f"[Error] No files found for {repo_id}\n")) + self._enqueue_ui(lambda: self._hf_set_status("No files found")) + return + + # Always show file selection dialog for user choice + self._enqueue_ui(lambda: self._show_file_selection_dialog(repo_id, files, download_dir)) + + except Exception as e: + self._enqueue_ui(lambda: self._append_output(f"[Download error] {e}\n")) + self._enqueue_ui(lambda: self._hf_set_status("Download preparation failed")) + + threading.Thread(target=prepare_downloads, daemon=True).start() + + def _show_file_selection_dialog(self, repo_id: str, files: List[Dict], download_dir: str): + """Show dialog to select files for download using the new FileSelectionDialog.""" + try: + # Use the new FileSelectionDialog + dialog = FileSelectionDialog( + parent=self.root, + repo_id=repo_id, + files=files, + title=f"Select Files to Download - {repo_id}" + ) + + result, selected_files = dialog.show() + + if result == 'download' and selected_files: + # Create download group + group_name = f"{repo_id.split('/')[-1] if '/' in repo_id else repo_id}" + group_description = f"Files from {repo_id}" + + group_id = self.download_manager.create_download_group( + repo_id=repo_id, + name=group_name, + description=group_description + ) + + # Add selected files to the group + download_count = 0 + for filename, file_info in selected_files: + url = f"{self.hf_api.base_url}/{repo_id}/resolve/main/{filename}" + save_path = os.path.join(download_dir, repo_id.replace("/", "_"), filename) + + self.download_manager.add_file_to_group( + group_id=group_id, + filename=filename, + url=url, + save_path=save_path, + headers=self.hf_api.headers, + selected=True + ) + download_count += 1 + + self._hf_set_status(f"Added {download_count} file(s) to download queue") + self._append_output(f"Created download group '{group_name}' with {download_count} files\n") + + # Switch to downloads tab to show the new group + self.notebook.select(self.downloads_frame) + else: + self._hf_set_status("Download cancelled") + + except Exception as e: + self._append_output(f"[Error] Failed to show file selection dialog: {e}\n") + self._hf_set_status("Error showing file selection") + + def _browse_gguf(self): + initial_dir = self.settings_manager.get('paths.models_directory', '.') + path = filedialog.askopenfilename( + title="Select Model (GGUF, Safetensors, GPTQ, AWQ, EXL2)", + initialdir=initial_dir, + filetypes=[ + ("All Model files", "*.gguf;*.safetensors;*.bin;*.pt;*.pth;*.exl2"), + ("GGUF files", "*.gguf"), + ("SafeTensors files", "*.safetensors"), + ("PyTorch files", "*.bin;*.pt;*.pth"), + ("GPTQ models", "*gptq*.safetensors;*gptq*.bin"), + ("AWQ models", "*awq*.safetensors;*awq*.bin"), + ("EXL2 files", "*.exl2"), + ("All files", "*.*") + ] + ) + if path: + self.model_var.set(path) + # Save last model path + self.settings_manager.set('paths.last_model_path', path) + self.settings_manager.save_settings() + + def _browse_folder(self): + initial_dir = self.settings_manager.get('paths.models_directory', '.') + path = filedialog.askdirectory( + title="Select Model Directory (HuggingFace format)", + initialdir=initial_dir + ) + if path: + self.model_var.set(path) + # Save last model path + self.settings_manager.set('paths.last_model_path', path) + self.settings_manager.save_settings() + + def _browse_lora(self): + path = filedialog.askopenfilename(title="Select LoRA/adapter file", filetypes=[("All files", "*.*")]) + if path: + self.lora_var.set(path) + + def _mark_model_unloaded(self): + """Mark model as unloaded and clear references""" + # Clear model references + if hasattr(self, 'current_model'): + self.current_model = None + if hasattr(chat_stream, '_unified_model_cache'): + chat_stream._unified_model_cache = None + + # Clear any cached models + global _llama_cache + _llama_cache["key"] = None + _llama_cache["llm"] = None + + # Clear agent reference + if hasattr(self, 'dhal_agent'): + self.dhal_agent = None + + # Update UI + self.model_status_var.set("[not loaded]") + self._update_load_button_text() + + # Clear chat history since model is unloaded + self.chat_history = [] + + def _disable_model_settings(self): + """Disable model settings that would unload the model""" + try: + # Disable model path entry and browse buttons + if hasattr(self, 'model_entry'): + self.model_entry.configure(state='disabled') + if hasattr(self, 'browse_model_btn'): + self.browse_model_btn.configure(state='disabled') + if hasattr(self, 'browse_folder_btn'): + self.browse_folder_btn.configure(state='disabled') + + # Disable settings that would trigger model reload + if hasattr(self, 'n_ctx_spin'): + self.n_ctx_spin.configure(state='disabled') + if hasattr(self, 'n_gpu_spin'): + self.n_gpu_spin.configure(state='disabled') + if hasattr(self, 'lora_entry'): + self.lora_entry.configure(state='disabled') + if hasattr(self, 'lora_btn'): + self.lora_btn.configure(state='disabled') + if hasattr(self, 'quantization_combo'): + self.quantization_combo.configure(state='disabled') + if hasattr(self, 'device_combo'): + self.device_combo.configure(state='disabled') + if hasattr(self, 'gpu_mem_spin'): + self.gpu_mem_spin.configure(state='disabled') + except Exception as e: + print(f"Error disabling model settings: {e}") + + def _enable_model_settings(self): + """Re-enable model settings""" + try: + # Re-enable model path entry and browse buttons + if hasattr(self, 'model_entry'): + self.model_entry.configure(state='normal') + if hasattr(self, 'browse_model_btn'): + self.browse_model_btn.configure(state='normal') + if hasattr(self, 'browse_folder_btn'): + self.browse_folder_btn.configure(state='normal') + + # Re-enable settings + if hasattr(self, 'n_ctx_spin'): + self.n_ctx_spin.configure(state='normal') + if hasattr(self, 'n_gpu_spin'): + self.n_gpu_spin.configure(state='normal') + if hasattr(self, 'lora_entry'): + self.lora_entry.configure(state='normal') + if hasattr(self, 'lora_btn'): + self.lora_btn.configure(state='normal') + if hasattr(self, 'quantization_combo'): + self.quantization_combo.configure(state='readonly') + if hasattr(self, 'device_combo'): + self.device_combo.configure(state='readonly') + if hasattr(self, 'gpu_mem_spin'): + self.gpu_mem_spin.configure(state='normal') + except Exception as e: + print(f"Error enabling model settings: {e}") + + def _init_agent_mode(self): + """Initialize agent mode handler""" + try: + # Import agent components + sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'agent_dhal')) + from agent_dhal.hal import create_dhal, DhalConfig, HalModelClient + + self.agent_enabled = True + self.dhal_agent = None + except Exception as e: + print(f"[Agent Mode] Could not import Hal components: {e}") + self.agent_enabled = False + self.agent_mode_var.set(False) + + def _on_agent_mode_changed(self): + """Handle agent mode toggle""" + if self.agent_mode_var.get(): + # Show warning + result = messagebox.askyesno( + "⚠️ Enable Agent Mode", + "WARNING: Agent mode gives the AI UNRESTRICTED access to:\n\n" + "• Your file system (read/write/delete)\n" + "• Shell commands (PowerShell, Bash, CMD)\n" + "• Mouse and keyboard control\n" + "• Python code execution\n" + "• Network requests\n" + "• System settings\n\n" + "The AI can control your computer completely!\n\n" + "Only enable if you trust the model and understand the risks.\n\n" + "Continue?", + icon='warning' + ) + + if not result: + self.agent_mode_var.set(False) + return + + # Initialize agent if needed + if self.agent_enabled and not self.dhal_agent: + try: + # Check if we have a loaded model + if not (hasattr(self, 'current_model') and self.current_model) and \ + not (hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache): + # No model loaded - show error + messagebox.showerror( + "No Model Loaded", + "Please load a model first before enabling Agent Mode.\n\n" + "Agent Mode requires a loaded language model to function." + ) + self.agent_mode_var.set(False) + return + + # Use the existing loaded model + from agent_dhal.hal import create_dhal + + # Create a wrapper for the existing local model + class LocalModelClient: + def __init__(self, model): + self.llm_model = model + + async def create_chat_completion(self, messages, **kwargs): + # Convert messages to prompt for local model + prompt = "" + for msg in messages: + if hasattr(msg, 'role') and hasattr(msg, 'content'): + role = msg.role + content = msg.content + else: + role = msg.get('role', 'user') + content = msg.get('content', '') + + if role == "system": + prompt += f"System: {content}\n" + elif role == "user": + prompt += f"User: {content}\n" + elif role == "assistant": + prompt += f"Assistant: {content}\n" + + prompt += "Assistant: " + + # Generate response using local model + from llm_runtime import GenerateConfig + cfg = GenerateConfig( + max_tokens=kwargs.get('max_tokens', 800), + temperature=kwargs.get('temperature', 0.7) + ) + response = self.llm_model.generate(prompt, cfg) + + # Return in expected format + return type('Response', (), { + 'choices': [type('Choice', (), { + 'message': type('Message', (), { + 'content': response + })() + })()] + })() + + # Use whichever model is available + if hasattr(self, 'current_model') and self.current_model: + model_client = LocalModelClient(self.current_model) + else: + model_client = LocalModelClient(chat_stream._unified_model_cache) + + # Create the agent with the local model + self.dhal_agent = create_dhal( + name="Hal", + system_message="You are Hal with full system access. You can execute any command the user requests. Always explain what you're doing before executing commands.", + model="local", # Just use "local" as identifier + model_client=model_client # Pass the wrapped local model + ) + + self._append_output("\n[AGENT MODE ACTIVATED] AI has full system control using local model\n") + self._append_output("Available tools: file operations, shell commands, Python execution, mouse/keyboard control\n\n") + except Exception as e: + messagebox.showerror("Error", f"Failed to initialize agent: {e}") + self.agent_mode_var.set(False) + else: + self._append_output("\n[AGENT MODE DEACTIVATED] Normal chat mode\n\n") + + def _on_auto_context_changed(self): + """Handle auto-context checkbox changes.""" + auto_context = self.auto_context_var.get() + self.settings_manager.set('model_settings.auto_context', auto_context) + self.settings_manager.save_settings() + + if auto_context: + # Disable manual context entry + if hasattr(self, 'n_ctx_spin'): + self.n_ctx_spin.configure(state='disabled') + + # Try to auto-detect context from currently selected model + model_path = self.model_var.get() + if model_path and _is_gguf_model(model_path): + try: + detected_n_ctx = _extract_gguf_int_metadata(model_path, "n_ctx_train") or \ + _extract_gguf_int_metadata(model_path, "n_ctx") + if detected_n_ctx: + self.n_ctx_var.set(detected_n_ctx) + self._append_output_threadsafe(f"[Auto-detected context size: {detected_n_ctx} tokens]\n") + except Exception as e: + print(f"Could not auto-detect context size: {e}") + else: + # Enable manual context entry + if hasattr(self, 'n_ctx_spin'): + self.n_ctx_spin.configure(state='normal') + + def _on_auto_gpu_changed(self): + """Handle auto-GPU checkbox changes.""" + auto_gpu = self.auto_gpu_var.get() + self.settings_manager.set('model_settings.auto_gpu', auto_gpu) + self.settings_manager.save_settings() + + if auto_gpu: + # Disable manual GPU entry + if hasattr(self, 'n_gpu_spin'): + self.n_gpu_spin.configure(state='disabled') + + # Auto-detect optimal GPU layers + model_path = self.model_var.get() + if model_path and self._has_gpu(): + try: + optimal_layers = self._calculate_optimal_gpu_layers(model_path) + if optimal_layers > 0: + self.n_gpu_layers_var.set(optimal_layers) + self._append_output_threadsafe(f"[Auto-configured GPU layers: {optimal_layers}]\n") + except Exception as e: + print(f"Could not auto-configure GPU layers: {e}") + else: + # Enable manual GPU entry + if hasattr(self, 'n_gpu_spin'): + self.n_gpu_spin.configure(state='normal') + + def _on_chess_mode_changed(self): + """Handle chess mode checkbox changes.""" + chess_mode = self.chess_mode_var.get() + self.settings_manager.set('model_settings.chess_mode', chess_mode) + self.settings_manager.save_settings() + + if chess_mode: + # Auto-configure for ChessGPT model + messagebox.showinfo( + "Chess Mode Enabled", + "Chess Mode enabled! This will use the ChessGPT model for chess-specific conversations.\n\n" + "Make sure you have the Waterhorse/chessgpt-chat-v1 model downloaded or use the HuggingFace browser to get it." + ) + # Mark model as unloaded since we're switching modes + self._mark_model_unloaded() + else: + # Reset to normal mode + self._mark_model_unloaded() + + def _on_advanced_setting_changed(self): + """Handle advanced loading settings changes.""" + # Save the settings + self.settings_manager.set('model_settings.quantization', self.quantization_var.get()) + self.settings_manager.set('model_settings.device_strategy', self.device_strategy_var.get()) + self.settings_manager.set('model_settings.gpu_memory_limit', self.gpu_memory_limit_var.get()) + self.settings_manager.save_settings() + + # Mark model as unloaded since these settings affect loading + self._mark_model_unloaded() + + def _on_sampling_setting_changed(self): + """Handle sampling parameter changes.""" + # Save the sampling parameters + self.settings_manager.set('model_settings.temperature', self.temperature_var.get()) + self.settings_manager.set('model_settings.top_p', self.top_p_var.get()) + self.settings_manager.set('model_settings.repetition_penalty', self.repetition_penalty_var.get()) + self.settings_manager.set('model_settings.no_repeat_ngram_size', self.no_repeat_ngram_size_var.get()) + self.settings_manager.set('model_settings.min_p', self.min_p_var.get()) + self.settings_manager.set('model_settings.typical_p', self.typical_p_var.get()) + self.settings_manager.save_settings() + + # Enqueue a callable to run on the Tk main thread + def _enqueue_ui(self, fn): + try: + self._ui_queue.put_nowait(fn) + except Exception: + pass + + # Periodically drain UI queue + def _drain_ui_queue(self): + try: + while True: + fn = self._ui_queue.get_nowait() + try: + fn() + except Exception: + pass + except queue.Empty: + pass + finally: + self.root.after(30, self._drain_ui_queue) + + def _set_status_threadsafe(self, text: str): + self._enqueue_ui(lambda: self.model_status_var.set(text)) + + def _append_output_threadsafe(self, text: str): + self._enqueue_ui(lambda t=text: self._append_output(t)) + + def _set_running(self, running: bool): + def _apply(): + try: + state_run = "disabled" if running else "normal" + state_stop = "normal" if running else "disabled" + if hasattr(self, "generate_btn"): + self.generate_btn.configure(state=state_run) + if hasattr(self, "send_btn"): + self.send_btn.configure(state=state_run) + if hasattr(self, "stop_btn"): + self.stop_btn.configure(state=state_stop) + except Exception: + pass + + self._enqueue_ui(_apply) + + def _on_stop(self): + try: + if getattr(self, "_current_cancel", None) is not None: + self._current_cancel.set() + except Exception: + pass + + def _on_load_unload_model(self): + """Handle both load and unload based on current state""" + if self._is_model_loaded(): + self._on_unload_model() + else: + self._on_load_model() + + def _is_model_loaded(self): + """Check if a model is currently loaded""" + return (hasattr(self, 'current_model') and self.current_model is not None) or \ + (hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache is not None) + + def _on_unload_model(self): + """Unload the currently loaded model""" + print("[APP_DEBUG] _on_unload_model() called") + + # Confirm before unloading + result = messagebox.askyesno( + "Unload Model", + "Are you sure you want to unload the current model?\n\n" + "This will clear the chat history and free GPU/CPU memory.", + icon='question' + ) + + if not result: + return + + # Clear model references + if hasattr(self, 'current_model'): + self.current_model = None + if hasattr(chat_stream, '_unified_model_cache'): + chat_stream._unified_model_cache = None + + # Clear any cached models + global _llama_cache + _llama_cache["key"] = None + _llama_cache["llm"] = None + + # Clear agent reference + if hasattr(self, 'dhal_agent'): + self.dhal_agent = None + + # Update UI + self._set_status_threadsafe("[not loaded]") + self._update_load_button_text() + self._append_output_threadsafe("[Model unloaded]\n") + + # Re-enable model settings + self._enable_model_settings() + + # Clear chat history since model is unloaded + self.chat_history = [] + + print("[APP_DEBUG] Model unloaded successfully") + + def _update_load_button_text(self): + """Update the load button text based on model state""" + if self._is_model_loaded(): + self.load_model_btn.configure(text="Unload Model") + else: + self.load_model_btn.configure(text="Load Model") + + def _on_load_model(self): + print("[APP_DEBUG] _on_load_model() called") + start_tracing() # Start detailed execution tracing + model = self.model_var.get().strip() + print(f"[APP_DEBUG] Model path: '{model}'") + if not _is_valid_model(model): + print("[APP_DEBUG] Invalid model detected") + stop_tracing() + messagebox.showerror("Load Model", + "Please select a valid model file (GGUF, Safetensors, or HuggingFace repo).") + return + n_ctx = self.n_ctx_var.get() + n_gpu = self.n_gpu_layers_var.get() + lora = self.lora_var.get().strip() or None + + # Auto-detect optimal settings before loading if enabled + if _is_gguf_model(model): + # Auto-configure context size + print(f"[CONTEXT_DEBUG] Auto-context enabled: {self.auto_context_var.get()}") + if self.auto_context_var.get(): + try: + print(f"[CONTEXT_DEBUG] Attempting to extract context metadata from: {model}") + n_ctx_train = _extract_gguf_int_metadata(model, "n_ctx_train") + n_ctx_fallback = _extract_gguf_int_metadata(model, "n_ctx") + print(f"[CONTEXT_DEBUG] n_ctx_train = {n_ctx_train}, n_ctx = {n_ctx_fallback}") + detected_n_ctx = n_ctx_train or n_ctx_fallback + print(f"[CONTEXT_DEBUG] detected_n_ctx = {detected_n_ctx}") + if detected_n_ctx: + n_ctx = detected_n_ctx + self.n_ctx_var.set(n_ctx) # Update the UI + print(f"[CONTEXT_DEBUG] Setting context size to {n_ctx}") + self._append_output_threadsafe( + f"[Auto-configuring context size to {n_ctx} tokens (model's trained capacity)]\n") + else: + print(f"[CONTEXT_DEBUG] No context metadata found, using default: {n_ctx}") + except Exception as e: + print(f"Could not auto-detect context size: {e}") + import traceback + traceback.print_exc() + else: + print(f"[CONTEXT_DEBUG] Auto-context disabled, using manual setting: {n_ctx}") + + # Auto-configure GPU layers + if self.auto_gpu_var.get(): + print(f"[GPU_DEBUG] Auto-GPU enabled, checking GPU availability...") + if self._has_gpu(): + print(f"[GPU_DEBUG] GPU detected, calculating optimal layers for model: {model}") + try: + optimal_layers = self._calculate_optimal_gpu_layers(model) + print(f"[GPU_DEBUG] Calculated optimal GPU layers: {optimal_layers}") + if optimal_layers > 0: + n_gpu = optimal_layers + self.n_gpu_layers_var.set(n_gpu) # Update the UI + self._append_output_threadsafe( + f"[Auto-configuring GPU layers to {n_gpu} for optimal performance]\n") + else: + print(f"[GPU_DEBUG] Optimal layers = 0, not updating n_gpu") + except Exception as e: + print(f"Could not auto-configure GPU layers: {e}") + else: + print(f"[GPU_DEBUG] No GPU detected, keeping CPU-only mode") + else: + print(f"[GPU_DEBUG] Auto-GPU disabled, using manual setting: {n_gpu}") + + self._set_status_threadsafe("[loading...]") + + # Create loading popup + loading_popup = tk.Toplevel(self.root) + loading_popup.title("Loading Model") + loading_popup.geometry("400x150") + loading_popup.resizable(False, False) + loading_popup.transient(self.root) + loading_popup.grab_set() + + # Center the popup + loading_popup.update_idletasks() + x = (loading_popup.winfo_screenwidth() // 2) - (loading_popup.winfo_width() // 2) + y = (loading_popup.winfo_screenheight() // 2) - (loading_popup.winfo_height() // 2) + loading_popup.geometry(f"+{x}+{y}") + + # Add loading message + tk.Label(loading_popup, text="Loading Model...", font=("Arial", 12, "bold")).pack(pady=10) + model_name = os.path.basename(model) if os.path.exists(model) else model + tk.Label(loading_popup, text=model_name, font=("Arial", 10)).pack(pady=5) + + # Progress bar + progress_var = tk.DoubleVar() + progress_bar = ttk.Progressbar(loading_popup, variable=progress_var, maximum=100, length=350, mode='indeterminate') + progress_bar.pack(pady=10) + progress_bar.start(10) + + # Status label + status_label = tk.Label(loading_popup, text="Initializing...", font=("Arial", 9)) + status_label.pack(pady=5) + + # Disable model settings while loading + self._disable_model_settings() + + def _run(): + try: + print("[APP_DEBUG] _run() started in loading thread") + # Load model using appropriate loader + if _is_gguf_model(model): + print("[APP_DEBUG] Detected GGUF model, using _get_llama()") + print(f"[APP_DEBUG] GGUF loading parameters: n_ctx={n_ctx}, n_gpu_layers={n_gpu}, lora={lora}") + # Use existing GGUF loading logic + gguf_model = _get_llama(model, n_ctx=n_ctx, n_gpu_layers=n_gpu, lora_path=lora) + # Store as current_model for agent integration + self.current_model = gguf_model + else: + print("[APP_DEBUG] Non-GGUF model detected, using unified loader") + # Use unified model loader for other formats and cache it + from llm_runtime import load_model + print("[APP_DEBUG] Imported load_model from llm_runtime") + + # Get advanced loading options + quantization = self.quantization_var.get() + device_strategy = self.device_strategy_var.get() + gpu_memory_limit = self.gpu_memory_limit_var.get() + print(f"[APP_DEBUG] Advanced options: quantization={quantization}, device_strategy={device_strategy}, gpu_memory_limit={gpu_memory_limit}") + + print(f"[APP_DEBUG] Calling load_model() with: model='{model}', device='auto'") + + # Ensure quantization is properly passed + load_kwargs = { + 'n_ctx': n_ctx, + 'n_gpu_layers': n_gpu, + 'device_strategy': device_strategy, + 'gpu_memory_limit': gpu_memory_limit, + 'device': "auto" + } + + # Only pass quantization if it's not 'none' + if quantization and quantization != 'none': + load_kwargs['quantization'] = quantization + print(f"[QUANTIZATION_DEBUG] Using quantization: {quantization}") + + unified_model = load_model(model, **load_kwargs) + print("[APP_DEBUG] load_model() completed successfully") + + # Cache the loaded model for chat function to reuse + chat_stream._unified_model_cache = unified_model + # Store as current_model for agent integration + self.current_model = unified_model + + # Warm up the model if supported + if hasattr(unified_model, 'warm_up_model'): + self._append_output_threadsafe("[Warming up model for optimal performance...]\n") + warmup_stats = unified_model.warm_up_model() + if warmup_stats.get('status') == 'success': + self._append_output_threadsafe(f"[Model warmed up in {warmup_stats['warmup_time']:.2f}s]\n") + else: + self._append_output_threadsafe(f"[Model warmup failed: {warmup_stats.get('error', 'unknown')}]\n") + + # Get model info if available + if hasattr(unified_model, 'get_model_info'): + model_info = unified_model.get_model_info() + self._append_output_threadsafe(f"[Model Info] {model_info.get('model_name', 'Unknown')}: {model_info.get('total_parameters', 'Unknown')} parameters\n") + self._append_output_threadsafe(f"[KV Cache] Enabled - Max context: {model_info.get('max_position_embeddings', 'Unknown')} tokens\n") + + # best-effort: try to detect language metadata from the GGUF file + lang = None + try: + lang = _extract_gguf_metadata(model, "language") or _extract_gguf_metadata(model, "lang") + except Exception: + lang = None + + # Show context info for user awareness + try: + detected_n_ctx = _extract_gguf_int_metadata(model, "n_ctx_train") or _extract_gguf_int_metadata( + model, "n_ctx") + except Exception: + detected_n_ctx = None + + if detected_n_ctx and detected_n_ctx != n_ctx and not self.auto_context_var.get(): + self._append_output_threadsafe( + f"[Model's trained context: {detected_n_ctx} tokens, using requested: {n_ctx} tokens]\n") + if detected_n_ctx > n_ctx: + self._append_output_threadsafe( + f"[Note: Enable 'Auto-configure context size' for optimal performance]\n") + + # Auto-configure optimal settings + self._auto_configure_model_settings(model, detected_n_ctx) + + if lang: + self._set_status_threadsafe(f"[loaded] ({lang})") + self._append_output_threadsafe(f"[Model language detected: {lang}]\n") + else: + self._set_status_threadsafe("[loaded]") + + # Update button text to "Unload Model" + self._enqueue_ui(self._update_load_button_text) + + # Close loading popup on success + self._enqueue_ui(lambda: loading_popup.destroy()) + + except Exception as e: + self._set_status_threadsafe("[error]") + self._append_output_threadsafe(f"[Load Error] {e}\n") + # Update button text back to "Load Model" on error + self._enqueue_ui(self._update_load_button_text) + # Close loading popup on error + self._enqueue_ui(lambda: loading_popup.destroy()) + # Re-enable model settings on error + self._enqueue_ui(self._enable_model_settings) + finally: + stop_tracing() # Stop tracing when loading completes or fails + + threading.Thread(target=_run, daemon=True).start() + + def _auto_configure_model_settings(self, model_path, detected_n_ctx=None): + """Auto-configure optimal GPU layers based on model and system resources""" + try: + # Auto-configure GPU layers based on available VRAM + if self._has_gpu(): + optimal_gpu_layers = self._calculate_optimal_gpu_layers(model_path) + if optimal_gpu_layers != self.n_gpu_layers_var.get(): + self.n_gpu_layers_var.set(optimal_gpu_layers) + self._append_output_threadsafe(f"[Auto-configured GPU layers to {optimal_gpu_layers}]\n") + except Exception as e: + self._append_output_threadsafe(f"[Auto-config warning: {e}]\n") + + def _calculate_optimal_gpu_layers(self, model_path): + """Calculate optimal number of GPU layers based on model size and available VRAM""" + try: + import torch + if not torch.cuda.is_available(): + return 0 + + # Get available VRAM + total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) + + # Reserve some VRAM for the system (1GB buffer) + available_vram_gb = max(0, total_vram_gb - 1.0) + + # Detect model size from path or filename + model_name = model_path.lower() + + # More comprehensive model size detection + if any(x in model_name for x in ['1b', '1.5b']): + # 1-1.5B models: ~0.5GB per layer, ~32 layers total + layers_per_gb = 6 + max_layers = 32 + elif any(x in model_name for x in ['3b', '3.8b']): + # 3B models: ~0.75GB per layer, ~32 layers total + layers_per_gb = 4 + max_layers = 32 + elif any(x in model_name for x in ['7b', '8b']): + # 7-8B models: ~1GB per layer, ~32 layers total + layers_per_gb = 3 + max_layers = 32 + elif any(x in model_name for x in ['13b', '14b']): + # 13-14B models: ~1.5GB per layer, ~40 layers total + layers_per_gb = 2 + max_layers = 40 + elif any(x in model_name for x in ['30b', '33b', '34b']): + # 30-34B models: ~2.5GB per layer, ~60 layers total + layers_per_gb = 1.2 + max_layers = 60 + elif any(x in model_name for x in ['65b', '70b']): + # 65-70B models: ~4GB per layer, ~80 layers total + layers_per_gb = 0.8 + max_layers = 80 + else: + # Unknown size - conservative estimate + layers_per_gb = 2 + max_layers = 32 + + # Calculate optimal layers based on available VRAM + optimal_layers = int(available_vram_gb * layers_per_gb) + + # Cap at the model's actual layer count + optimal_layers = min(optimal_layers, max_layers) + + # Ensure at least some layers go to GPU if we have VRAM + if available_vram_gb >= 2.0 and optimal_layers < 1: + optimal_layers = 1 + + return max(0, optimal_layers) + + except Exception as e: + print(f"Error calculating GPU layers: {e}") + return 0 + + def _has_gpu(self): + """Check if GPU is available for acceleration""" + try: + import torch + return torch.cuda.is_available() + except: + return False + + def _initialize_resource_monitoring(self): + """Initialize resource monitoring components""" + self._refresh_resources() + + def _test_gpu(self): + """Test GPU functionality by running a small inference""" + + def test(): + try: + import torch + if not torch.cuda.is_available(): + self.gpu_info_var.set("No GPU detected") + return + + # Basic GPU test + device = torch.device("cuda:0") + test_tensor = torch.randn(1000, 1000).to(device) + result = torch.matmul(test_tensor, test_tensor) + torch.cuda.synchronize() + + gpu_name = torch.cuda.get_device_name(0) + self.gpu_info_var.set(f"GPU Test PASSED: {gpu_name}") + self._append_output_threadsafe("[GPU test completed successfully]\n") + + except Exception as e: + self.gpu_info_var.set(f"GPU Test FAILED: {e}") + self._append_output_threadsafe(f"[GPU test failed: {e}]\n") + + threading.Thread(target=test, daemon=True).start() + + def _refresh_resources(self): + """Refresh resource usage information""" + + def refresh(): + try: + import psutil + + # CPU Info + cpu_count = psutil.cpu_count(logical=False) + cpu_count_logical = psutil.cpu_count(logical=True) + self.cpu_info_var.set(f"CPU: {cpu_count} cores ({cpu_count_logical} threads)") + + # CPU Usage + cpu_percent = psutil.cpu_percent(interval=1) + self.cpu_usage_var.set(f"CPU Usage: {cpu_percent:.1f}%") + + # RAM Usage + memory = psutil.virtual_memory() + ram_gb_used = memory.used / (1024 ** 3) + ram_gb_total = memory.total / (1024 ** 3) + self.ram_usage_var.set(f"RAM: {ram_gb_used:.1f}GB / {ram_gb_total:.1f}GB ({memory.percent:.1f}%)") + + # GPU Info + try: + import torch + self.gpu_info_var.set(f"PyTorch version: {torch.__version__}") + + if torch.cuda.is_available(): + gpu_count = torch.cuda.device_count() + gpu_name = torch.cuda.get_device_name(0) + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) + gpu_allocated = torch.cuda.memory_allocated(0) / (1024 ** 3) + gpu_reserved = torch.cuda.memory_reserved(0) / (1024 ** 3) + cuda_version = torch.version.cuda + + self.gpu_info_var.set(f"GPU: {gpu_name} (CUDA {cuda_version}) - {gpu_count} device(s)") + self.gpu_memory_var.set( + f"VRAM: {gpu_allocated:.1f}GB allocated, {gpu_reserved:.1f}GB reserved / {gpu_memory:.1f}GB total") + self.gpu_usage_var.set(f"GPU Usage: {(gpu_allocated / gpu_memory) * 100:.1f}%") + else: + # More detailed error info + cuda_available = hasattr(torch.backends, 'cuda') and torch.backends.cuda.is_built() + self.gpu_info_var.set(f"No CUDA GPU available (CUDA built: {cuda_available})") + self.gpu_memory_var.set("VRAM: N/A - Check CUDA installation") + self.gpu_usage_var.set("GPU Usage: N/A") + + except ImportError as e: + self.gpu_info_var.set(f"PyTorch not available: {e}") + self.gpu_memory_var.set("VRAM: Install PyTorch with CUDA support") + self.gpu_usage_var.set("GPU Usage: Unknown") + except Exception as e: + self.gpu_info_var.set(f"GPU detection error: {e}") + self.gpu_memory_var.set(f"VRAM: Error - {str(e)}") + self.gpu_usage_var.set("GPU Usage: Error") + + except Exception as e: + self.cpu_info_var.set(f"Error: {e}") + + threading.Thread(target=refresh, daemon=True).start() + + def _toggle_monitoring(self): + """Toggle real-time resource monitoring""" + if self.monitor_var.get(): + self._start_monitoring() + else: + self._stop_monitoring() + + def _start_monitoring(self): + """Start real-time monitoring loop""" + + def monitor_loop(): + while self.monitor_var.get(): + self._refresh_resources() + time.sleep(2) # Update every 2 seconds + + if not hasattr(self, '_monitor_thread') or not self._monitor_thread.is_alive(): + self._monitor_thread = threading.Thread(target=monitor_loop, daemon=True) + self._monitor_thread.start() + + def _stop_monitoring(self): + """Stop real-time monitoring""" + # Thread will stop on next iteration when monitor_var.get() returns False + pass + + def _choose_models_folder(self): + initial_dir = self.settings_manager.get('paths.models_directory', '.') + folder = filedialog.askdirectory(title="Select models folder", initialdir=initial_dir) + if folder: + self.models_dir_var.set(folder) + # Save to settings + self.settings_manager.set('paths.models_directory', folder) + self.settings_manager.save_settings() + self._refresh_local_models() + + def _refresh_local_models(self): + folder = (self.models_dir_var.get() or "").strip() + self._local_model_paths.clear() + values: List[str] = [] + if folder and os.path.isdir(folder): + try: + for name in os.listdir(folder): + name_lower = name.lower() + # Check for all supported model formats + if (name_lower.endswith((".gguf", ".safetensors", ".bin", ".pt", ".pth", ".exl2")) or + ('gptq' in name_lower and name_lower.endswith(('.safetensors', '.bin'))) or + ('awq' in name_lower and name_lower.endswith(('.safetensors', '.bin')))): + full = os.path.join(folder, name) + display = name + self._local_model_paths[display] = full + values.append(display) + except Exception: + pass + self.local_models_combo["values"] = values + # keep selection if still present + current_display = self.local_model_var.get() + if current_display not in values: + self.local_model_var.set(values[0] if values else "") + + def _on_local_model_selected(self, event=None): + display = self.local_model_var.get() + path = self._local_model_paths.get(display) + if path: + self.model_var.set(path) + + def _append_output(self, text: str): + self.output_text.insert(tk.END, text) + self.output_text.see(tk.END) + + def _on_generate(self): + model = self.model_var.get().strip() + prompt = self.prompt_text.get('1.0', tk.END).strip() + if not _is_valid_model(model): + messagebox.showerror("Model", "Please select a valid model file.") + return + if not prompt: + messagebox.showinfo("Generate", "Please enter a prompt.") + return + self.output_text.delete('1.0', tk.END) + n_ctx = self.n_ctx_var.get() + n_gpu = self.n_gpu_layers_var.get() + lora = self.lora_var.get().strip() or None + + # Retain memory by recording the user turn + self.chat_history.append({"role": "user", "content": prompt}) + + cancel = threading.Event() + self._current_cancel = cancel + self._set_running(True) + + def run(): + try: + content = run_prompt( + model, + prompt, + self.stream_var.get(), + n_ctx=n_ctx, + n_gpu_layers=n_gpu, + lora_path=lora, + on_chunk=self._append_output_threadsafe, + n_threads=None, + max_tokens=self.max_tokens_var.get(), + history=self.chat_history, + cancel_event=cancel, + chess_mode=self.chess_mode_var.get(), + ) + # Record assistant turn for future context + self.chat_history.append({"role": "assistant", "content": content}) + self._append_output_threadsafe("\n" if not cancel.is_set() else "\n[stopped]\n") + except Exception as e: + self._append_output_threadsafe(f"\n[Error] {e}\n") + finally: + self._set_running(False) + self._current_cancel = None + + threading.Thread(target=run, daemon=True).start() + + def _on_chat(self): + print("DEBUG: _on_chat called") + model = self.model_var.get().strip() + user = self.prompt_text.get('1.0', tk.END).strip() + print(f"DEBUG: model='{model}', user='{user}'") + + # Check for agent mode + if self.agent_mode_var.get(): + # Use simple agent mode + self._handle_agent_chat_simple(user) + return + + if not _is_valid_model(model): + print("DEBUG: Invalid model") + messagebox.showerror("Model", "Please select a valid model file.") + return + if not user: + print("DEBUG: No user input") + messagebox.showinfo("Chat", "Please enter a message.") + return + print("DEBUG: Starting chat processing") + n_ctx = self.n_ctx_var.get() + n_gpu = self.n_gpu_layers_var.get() + lora = self.lora_var.get().strip() or None + + self.chat_history.append({"role": "user", "content": user}) + self._append_output(f"You: {user}\nAssistant: ") + + cancel = threading.Event() + self._current_cancel = cancel + self._set_running(True) + + def run(): + try: + print("DEBUG: Calling chat_stream") + content = chat_stream( + model, + self.chat_history, + n_ctx=n_ctx, + n_gpu_layers=n_gpu, + lora_path=lora, + on_chunk=self._append_output_threadsafe, + n_threads=None, + max_tokens=self.max_tokens_var.get(), + cancel_event=cancel, + chess_mode=self.chess_mode_var.get(), + chat_template=self.chat_template_var.get(), + session_id=self._session_id, + ) + print(f"DEBUG: Got response: '{content}'") + self.chat_history.append({"role": "assistant", "content": content}) + self._append_output_threadsafe("\n" if not cancel.is_set() else "\n[stopped]\n") + except Exception as e: + print(f"DEBUG: Chat error: {e}") + self._append_output_threadsafe(f"\n[Error] {e}\n") + finally: + self._set_running(False) + self._current_cancel = None + + threading.Thread(target=run, daemon=True).start() + + def _handle_agent_chat_simple(self, user_message: str): + """Simple agent mode that directly executes commands""" + if not user_message: + messagebox.showinfo("Chat", "Please enter a message.") + return + + self._append_output(f"You: {user_message}\n") + self._set_running(True) + + # Create agent activity popup + agent_popup = tk.Toplevel(self.root) + agent_popup.title("Agent Activity Monitor") + agent_popup.geometry("500x300") + agent_popup.resizable(True, True) + agent_popup.transient(self.root) + + # Center the popup + agent_popup.update_idletasks() + x = (agent_popup.winfo_screenwidth() // 2) - (agent_popup.winfo_width() // 2) + y = (agent_popup.winfo_screenheight() // 2) - (agent_popup.winfo_height() // 2) + agent_popup.geometry(f"+{x}+{y}") + + # Add activity display + tk.Label(agent_popup, text="🤖 Agent Activity Monitor", font=("Arial", 12, "bold")).pack(pady=5) + + # Activity log + log_frame = tk.Frame(agent_popup) + log_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) + + activity_log = tk.Text(log_frame, height=15, bg='#1e1e1e', fg='#00ff00', font=('Consolas', 9)) + scrollbar = tk.Scrollbar(log_frame, orient=tk.VERTICAL, command=activity_log.yview) + activity_log.configure(yscrollcommand=scrollbar.set) + + activity_log.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # Close button + tk.Button(agent_popup, text="Close", command=agent_popup.destroy).pack(pady=5) + + def log_activity(message): + """Add message to activity log""" + try: + activity_log.insert(tk.END, f"[{time.strftime('%H:%M:%S')}] {message}\n") + activity_log.see(tk.END) + activity_log.update() + except: + pass + + def run_simple_agent(): + try: + # Import simple agent + from simple_agent_mode import SimpleAgentExecutor + agent = SimpleAgentExecutor(log_callback=log_activity) + + log_activity("Agent initialized, analyzing request...") + + # Use AI model to intelligently determine what actions to take + log_activity(f"AI analyzing request: {user_message}") + + # Generate intelligent response using the loaded model + if self.current_model: + try: + # Create a comprehensive system prompt for the agent + agent_system_prompt = f"""You are an AI assistant with system access. Analyze the user's request and provide the exact Windows commands needed. + +User Request: "{user_message}" + +For virus scanning, use Windows Defender PowerShell commands: +- Get-MpComputerStatus: Check antivirus status +- Start-MpScan -ScanType QuickScan: Quick virus scan +- Start-MpScan -ScanType FullScan: Full system scan +- Update-MpSignature: Update virus definitions + +Respond with ONLY the PowerShell command(s) needed, one per line: + +""" + + # Get AI response for dynamic command generation + from llm_runtime import GenerateConfig + cfg = GenerateConfig(max_tokens=800, temperature=0.1) + + try: + # Check if this is a GGUF model (llama-cpp-python) + if hasattr(self.current_model, 'create_completion'): + # Use llama-cpp-python's native method with proper parameters + log_activity("Using GGUF model native completion method") + completion = self.current_model.create_completion( + prompt=agent_system_prompt, + max_tokens=800, + temperature=0.1, + stop=["\n\n", "Human:", "User:"], + echo=False + ) + ai_response = completion['choices'][0]['text'].strip() + + else: + # Use unified runtime method + log_activity("Using unified runtime method") + raw_response = self.current_model.generate(agent_system_prompt, cfg) + + # Handle different response types + if isinstance(raw_response, str): + ai_response = raw_response + elif hasattr(raw_response, '__iter__') and not isinstance(raw_response, str): + # It's a generator or iterable, collect tokens + tokens = [] + for token in raw_response: + if isinstance(token, str): + tokens.append(token) + else: + tokens.append(str(token)) + ai_response = ''.join(tokens) + else: + # Fallback: convert to string + ai_response = str(raw_response) + + log_activity(f"AI generated action plan: {ai_response}") + + except Exception as gen_error: + log_activity(f"Error generating AI response: {gen_error}") + import traceback + log_activity(f"Full error traceback: {traceback.format_exc()}") + ai_response = "Get-MpComputerStatus; Start-MpScan -ScanType QuickScan" # Fallback command + + # Execute the AI's action plan - treat as PowerShell commands + if ai_response.strip(): + # Split into individual commands and execute each as PowerShell + commands = [cmd.strip() for cmd in ai_response.strip().split('\n') if cmd.strip()] + results = [] + for cmd in commands: + if cmd and not cmd.startswith('#'): # Skip comments + self._append_output_threadsafe(f"[Executing] {cmd}\n") + result = agent.tools["powershell"](cmd) + results.append(f"Command: {cmd}\nResult: {result}") + + combined_result = "\n\n".join(results) + else: + combined_result = "No commands generated" + + result = combined_result + self._append_output_threadsafe(f"[AGENT]: {result}\n") + + except Exception as e: + log_activity(f"Error in AI command generation: {e}") + # Simple fallback - just pass the request to the agent for basic parsing + result = agent.process_request(user_message, f"The user wants: {user_message}") + self._append_output_threadsafe(f"[AGENT]: {result}\n") + else: + # No model loaded - basic fallback processing + log_activity("No model loaded, using basic command processing") + result = agent.process_request(user_message, f"Please help with: {user_message}") + self._append_output_threadsafe(f"[AGENT]: {result}\n") + + log_activity("Agent processing completed") + + except Exception as e: + log_activity(f"Agent error: {e}") + self._append_output_threadsafe(f"\n[Agent Error]: {e}\n") + import traceback + traceback.print_exc() + finally: + log_activity("Agent task completed") + self._set_running(False) + # Close agent popup after 5 seconds + self.root.after(5000, lambda: agent_popup.destroy()) + + threading.Thread(target=run_simple_agent, daemon=True).start() + + def _handle_agent_chat(self, user_message: str): + """Handle chat in agent mode with full system access""" + if not user_message: + messagebox.showinfo("Chat", "Please enter a message.") + return + + self._append_output(f"You: {user_message}\n[AGENT]: ") + self._set_running(True) + + def run_agent(): + try: + if not self.dhal_agent: + # Use the already loaded model if available + if hasattr(chat_stream, '_unified_model_cache') and chat_stream._unified_model_cache: + # Use the cached model directly + from agent_dhal.hal import Dhal, DhalConfig + + # Create a simple pass-through model client that uses our cached model + class CachedModelClient: + def __init__(self, cached_model): + self.model = cached_model + + async def create_chat_completion(self, messages, **kwargs): + # Convert messages to prompt + prompt = "" + for msg in messages: + if hasattr(msg, 'role') and hasattr(msg, 'content'): + role = msg.role + content = msg.content + else: + role = msg.get('role', 'user') + content = msg.get('content', '') + + if role == "system": + prompt += f"System: {content}\n\n" + elif role == "user": + prompt += f"User: {content}\n\n" + elif role == "assistant": + prompt += f"Assistant: {content}\n\n" + + prompt += "Assistant: " + + # Generate response using cached model + try: + from llm_runtime import GenerateConfig + cfg = GenerateConfig( + max_tokens=kwargs.get('max_tokens', 2000), + temperature=kwargs.get('temperature', 0.7) + ) + response = self.model.generate(prompt, cfg) + except: + # Fallback for models without GenerateConfig + response = self.model.generate(prompt) + + # Create response object + class CompletionResponse: + def __init__(self, content): + self.content = content + self.function_calls = None + + return CompletionResponse(response) + + def is_available(self): + return True + + config = DhalConfig( + name="Hal", + system_message="You are Hal with full system access. You can execute any command the user requests. Always explain what you're doing before executing commands.", + model=self.model_var.get() or "cached" + ) + + model_client = CachedModelClient(chat_stream._unified_model_cache) + self.dhal_agent = Dhal(config, model_client) + else: + # Fallback to creating new agent + from agent_dhal.hal import create_dhal + self.dhal_agent = create_dhal( + name="Hal", + system_message="You are Hal with full system access. You can execute any command the user requests. Always explain what you're doing before executing commands.", + model=self.model_var.get() or "gpt-4" + ) + + # Process message through agent + import asyncio + + async def process(): + # Create mock context + class MockContext: + def __init__(self): + self.agent_id = "user" + + response = await self.dhal_agent.handle_user_message(user_message, MockContext()) + return response + + # Run async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(process()) + self._append_output_threadsafe(response + "\n") + finally: + loop.close() + + except Exception as e: + self._append_output_threadsafe(f"\n[Agent Error] {e}\n") + import traceback + traceback.print_exc() + finally: + self._set_running(False) + + threading.Thread(target=run_agent, daemon=True).start() + + +# ---------------- CLI (Embedded only) ----------------- + +def parse_args(argv: List[str]) -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Embedded llama.cpp app (no external daemons). Runs local GGUF models via llama-cpp-python.") + p.add_argument("--gui", action="store_true", help="Launch the GUI.") + p.add_argument("--model", required=False, help="Path to a local GGUF model file.") + p.add_argument("--prompt", + help="Single prompt to generate a response for. If omitted with no --gui, starts interactive chat mode.") + p.add_argument("--stream", action="store_true", help="Stream output tokens for single-prompt mode.") + p.add_argument("--n_ctx", type=int, default=4096, help="Context window size (default: 4096)") + p.add_argument("--n_gpu_layers", type=int, default=0, help="GPU layers to offload (default: 0 = CPU)") + p.add_argument("--lora", help="Optional LoRA/adapter file path to apply.") + return p.parse_args(argv) + + +def launch_main_gui(acceleration_type=None): + """Launch the main DarkHal 2.0 GUI application with hardware acceleration""" + if acceleration_type: + print(f"Launching DarkHal 2.0 with {acceleration_type.upper()} acceleration...") + # Set default GPU layers based on acceleration type + if acceleration_type == 'cuda': + # Use high GPU offloading for CUDA + os.environ['DARKHAL_DEFAULT_GPU_LAYERS'] = '32' + elif acceleration_type == 'intel': + # Moderate GPU offloading for Intel GPU + os.environ['DARKHAL_DEFAULT_GPU_LAYERS'] = '16' + elif acceleration_type == 'cpu': + # No GPU offloading for CPU-only mode + os.environ['DARKHAL_DEFAULT_GPU_LAYERS'] = '0' + + root = tk.Tk() + app = EmbeddedGUI(root) + root.mainloop() + + +def main(argv: List[str]) -> int: + print(f"[APP_DEBUG] main() called with argv: {argv}") + args = parse_args(argv) + print(f"[APP_DEBUG] Parsed args: {args}") + + # Default to GUI when no CLI-specific args are provided, or when --gui is passed + if args.gui or (not args.model and not args.prompt): + print("[APP_DEBUG] Starting GUI mode") + # Show splash screen then launch main app + splash_manager = SplashManager(main_app_callback=launch_main_gui) + splash_manager.show_splash_and_launch() + return 0 + + # CLI mode requires a GGUF model path + if not args.model or not _is_gguf_model(args.model): + print("Please provide --model pointing to a local .gguf file (or run with --gui).", file=sys.stderr) + return 2 + + if args.prompt: + out = run_prompt(args.model, args.prompt, args.stream, n_ctx=args.n_ctx, n_gpu_layers=args.n_gpu_layers, + lora_path=(args.lora or None)) + print(out) + else: + # Interactive chat + messages: List[Dict[str, Any]] = [] + print("Starting interactive chat. Type 'exit' or 'quit' to leave.") + while True: + try: + user = input("You> ").strip() + except (EOFError, KeyboardInterrupt): + print("\nExiting.") + break + if user.lower() in {"exit", "quit"}: + print("Goodbye!") + break + if not user: + continue + messages.append({"role": "user", "content": user}) + try: + print("Assistant> ", end="", flush=True) + + def _print_chunk(s: str): + print(s, end="", flush=True) + + assistant_content = chat_stream(args.model, messages, n_ctx=args.n_ctx, n_gpu_layers=args.n_gpu_layers, + lora_path=(args.lora or None), on_chunk=_print_chunk, chat_template=None, session_id=None) + print() + messages.append({"role": "assistant", "content": assistant_content}) + except Exception as e: + print(f"\n[Error] {e}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/mcp_config.py b/mcp_config.py new file mode 100644 index 0000000..ec5e6bb --- /dev/null +++ b/mcp_config.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +""" +MCP Server Configuration and Management + +This module provides utilities to configure and manage the MCP server +for LLM_Train models. +""" + +import json +import os +import subprocess +import sys +import tkinter as tk +from tkinter import ttk, messagebox, filedialog +from typing import Dict, Any, Optional +from pathlib import Path + + +class MCPServerConfig: + """Configuration management for the MCP server.""" + + def __init__(self, config_file: str = "mcp_config.json"): + self.config_file = config_file + self.default_config = { + "server": { + "name": "llm-train-models", + "version": "1.0.0", + "description": "Multi-model MCP server for LLM_Train", + "host": "localhost", + "port": 8000, + "auto_start": False + }, + "models": { + "cache_size": 3, + "default_context": 4096, + "default_gpu_layers": 0 + }, + "logging": { + "level": "INFO", + "file": "mcp_server.log" + } + } + self.config = self.load_config() + + def load_config(self) -> Dict[str, Any]: + """Load configuration from file.""" + try: + if os.path.exists(self.config_file): + with open(self.config_file, 'r') as f: + loaded = json.load(f) + # Merge with defaults + return self._merge_config(self.default_config, loaded) + return self.default_config.copy() + except Exception as e: + print(f"Error loading MCP config: {e}") + return self.default_config.copy() + + def _merge_config(self, defaults: Dict, loaded: Dict) -> Dict: + """Merge loaded config with defaults.""" + result = defaults.copy() + for key, value in loaded.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = self._merge_config(result[key], value) + else: + result[key] = value + return result + + def save_config(self) -> bool: + """Save configuration to file.""" + try: + with open(self.config_file, 'w') as f: + json.dump(self.config, f, indent=2) + return True + except Exception as e: + print(f"Error saving MCP config: {e}") + return False + + def get(self, path: str, default: Any = None) -> Any: + """Get config value using dot notation.""" + keys = path.split('.') + value = self.config + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + return default + return value + + def set(self, path: str, value: Any): + """Set config value using dot notation.""" + keys = path.split('.') + target = self.config + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value + + def generate_claude_config(self) -> Dict[str, Any]: + """Generate Claude Desktop MCP configuration.""" + script_path = os.path.abspath("mcp_server.py") + python_path = sys.executable + + return { + "mcpServers": { + "llm-train-models": { + "command": python_path, + "args": [script_path], + "env": { + "PYTHONPATH": os.getcwd() + } + } + } + } + + +class MCPConfigGUI: + """GUI for configuring the MCP server.""" + + def __init__(self, parent: tk.Tk): + self.parent = parent + self.config = MCPServerConfig() + self.dialog = tk.Toplevel(parent) + self.dialog.title("MCP Server Configuration") + self.dialog.geometry("600x500") + self.dialog.transient(parent) + self.dialog.grab_set() + + self._build_ui() + self._load_values() + + def _build_ui(self): + """Build the configuration UI.""" + notebook = ttk.Notebook(self.dialog) + notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Server Settings Tab + server_frame = ttk.Frame(notebook) + notebook.add(server_frame, text="Server") + self._build_server_tab(server_frame) + + # Model Settings Tab + model_frame = ttk.Frame(notebook) + notebook.add(model_frame, text="Models") + self._build_models_tab(model_frame) + + # Claude Integration Tab + claude_frame = ttk.Frame(notebook) + notebook.add(claude_frame, text="Claude Integration") + self._build_claude_tab(claude_frame) + + # Buttons + button_frame = ttk.Frame(self.dialog) + button_frame.pack(fill=tk.X, padx=10, pady=(0, 10)) + + ttk.Button(button_frame, text="Save", command=self._save_config).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="Cancel", command=self.dialog.destroy).pack(side=tk.RIGHT) + ttk.Button(button_frame, text="Test Server", command=self._test_server).pack(side=tk.LEFT) + + def _build_server_tab(self, parent: ttk.Frame): + """Build server settings tab.""" + frame = ttk.LabelFrame(parent, text="Server Configuration", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Server name + ttk.Label(frame, text="Server Name:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.server_name_var = tk.StringVar() + ttk.Entry(frame, textvariable=self.server_name_var, width=30).grid(row=0, column=1, sticky=tk.W, pady=5) + + # Auto start + self.auto_start_var = tk.BooleanVar() + ttk.Checkbutton(frame, text="Auto-start server with application", + variable=self.auto_start_var).grid(row=1, column=0, columnspan=2, sticky=tk.W, pady=10) + + # Logging level + ttk.Label(frame, text="Logging Level:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.log_level_var = tk.StringVar() + ttk.Combobox(frame, textvariable=self.log_level_var, + values=["DEBUG", "INFO", "WARNING", "ERROR"], + state="readonly", width=15).grid(row=2, column=1, sticky=tk.W, pady=5) + + # Log file + ttk.Label(frame, text="Log File:").grid(row=3, column=0, sticky=tk.W, pady=5) + log_frame = ttk.Frame(frame) + log_frame.grid(row=3, column=1, sticky=tk.W, pady=5) + + self.log_file_var = tk.StringVar() + ttk.Entry(log_frame, textvariable=self.log_file_var, width=25).pack(side=tk.LEFT) + ttk.Button(log_frame, text="Browse", + command=self._browse_log_file).pack(side=tk.LEFT, padx=5) + + def _build_models_tab(self, parent: ttk.Frame): + """Build model settings tab.""" + frame = ttk.LabelFrame(parent, text="Model Configuration", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Cache size + ttk.Label(frame, text="Model Cache Size:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.cache_size_var = tk.IntVar() + cache_spin = tk.Spinbox(frame, from_=1, to=10, textvariable=self.cache_size_var, width=10) + cache_spin.grid(row=0, column=1, sticky=tk.W, pady=5) + ttk.Label(frame, text="models").grid(row=0, column=2, sticky=tk.W, padx=5) + + # Default context + ttk.Label(frame, text="Default Context Size:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.default_ctx_var = tk.IntVar() + ctx_spin = tk.Spinbox(frame, from_=512, to=32768, increment=512, + textvariable=self.default_ctx_var, width=10) + ctx_spin.grid(row=1, column=1, sticky=tk.W, pady=5) + ttk.Label(frame, text="tokens").grid(row=1, column=2, sticky=tk.W, padx=5) + + # Default GPU layers + ttk.Label(frame, text="Default GPU Layers:").grid(row=2, column=0, sticky=tk.W, pady=5) + self.default_gpu_var = tk.IntVar() + gpu_spin = tk.Spinbox(frame, from_=0, to=100, textvariable=self.default_gpu_var, width=10) + gpu_spin.grid(row=2, column=1, sticky=tk.W, pady=5) + ttk.Label(frame, text="layers").grid(row=2, column=2, sticky=tk.W, padx=5) + + # Info + info_text = ("Cache size determines how many models can be loaded simultaneously.\\n" + "Default settings are used when loading models via MCP.") + ttk.Label(frame, text=info_text, foreground="gray").grid(row=3, column=0, columnspan=3, pady=10) + + def _build_claude_tab(self, parent: ttk.Frame): + """Build Claude integration tab.""" + frame = ttk.LabelFrame(parent, text="Claude Desktop Integration", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Instructions - wrapped properly + instructions_frame = ttk.Frame(frame) + instructions_frame.pack(fill=tk.X, pady=(0, 10)) + + # Create a text widget for instructions with proper wrapping + instructions_text = tk.Text(instructions_frame, height=8, wrap=tk.WORD, + background="#f0f0f0", relief=tk.FLAT, + borderwidth=0, state=tk.DISABLED) + instructions_text.pack(fill=tk.X) + + # Add properly formatted instructions + instructions_content = """To use this MCP server with Claude Desktop, follow these steps: + +1. Locate your Claude Desktop configuration file: + • Windows: %APPDATA%\\Claude\\claude_desktop_config.json + • macOS: ~/Library/Application Support/Claude/claude_desktop_config.json + • Linux: ~/.config/Claude/claude_desktop_config.json + +2. Open the configuration file in a text editor (create it if it doesn't exist) + +3. Add the JSON configuration shown below to the file + +4. Save the file and restart Claude Desktop + +5. The LLM_Train models server will be available in Claude Desktop""" + + instructions_text.config(state=tk.NORMAL) + instructions_text.insert(1.0, instructions_content) + instructions_text.config(state=tk.DISABLED) + + # Config display + ttk.Label(frame, text="Configuration JSON:", font=("TkDefaultFont", 10, "bold")).pack(anchor=tk.W, pady=(10,5)) + + # Create text frame with proper layout + text_frame = ttk.Frame(frame) + text_frame.pack(fill=tk.BOTH, expand=True) + + # Make the config text read-only and styled + self.config_text = tk.Text(text_frame, height=12, width=70, wrap=tk.WORD, + state=tk.DISABLED, background="#f8f8f8", + font=("Courier", 9)) + config_scroll = ttk.Scrollbar(text_frame, orient="vertical", command=self.config_text.yview) + self.config_text.configure(yscrollcommand=config_scroll.set) + + self.config_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + config_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Buttons + button_frame = ttk.Frame(frame) + button_frame.pack(fill=tk.X, pady=10) + + ttk.Button(button_frame, text="Copy to Clipboard", + command=self._copy_config).pack(side=tk.LEFT, padx=5) + ttk.Button(button_frame, text="Save to File", + command=self._save_claude_config).pack(side=tk.LEFT, padx=5) + ttk.Button(button_frame, text="Refresh", + command=self._update_claude_config).pack(side=tk.LEFT, padx=5) + + self._update_claude_config() + + def _load_values(self): + """Load configuration values into UI.""" + self.server_name_var.set(self.config.get('server.name', 'llm-train-models')) + self.auto_start_var.set(self.config.get('server.auto_start', False)) + self.log_level_var.set(self.config.get('logging.level', 'INFO')) + self.log_file_var.set(self.config.get('logging.file', 'mcp_server.log')) + + self.cache_size_var.set(self.config.get('models.cache_size', 3)) + self.default_ctx_var.set(self.config.get('models.default_context', 4096)) + self.default_gpu_var.set(self.config.get('models.default_gpu_layers', 0)) + + def _save_config(self): + """Save configuration.""" + self.config.set('server.name', self.server_name_var.get()) + self.config.set('server.auto_start', self.auto_start_var.get()) + self.config.set('logging.level', self.log_level_var.get()) + self.config.set('logging.file', self.log_file_var.get()) + + self.config.set('models.cache_size', self.cache_size_var.get()) + self.config.set('models.default_context', self.default_ctx_var.get()) + self.config.set('models.default_gpu_layers', self.default_gpu_var.get()) + + if self.config.save_config(): + messagebox.showinfo("Success", "Configuration saved successfully!") + self.dialog.destroy() + else: + messagebox.showerror("Error", "Failed to save configuration") + + def _browse_log_file(self): + """Browse for log file location.""" + filename = filedialog.asksaveasfilename( + title="Select Log File", + defaultextension=".log", + filetypes=[("Log files", "*.log"), ("All files", "*.*")] + ) + if filename: + self.log_file_var.set(filename) + + def _update_claude_config(self): + """Update the Claude configuration display.""" + config = self.config.generate_claude_config() + config_json = json.dumps(config, indent=2) + + self.config_text.delete(1.0, tk.END) + self.config_text.insert(1.0, config_json) + + def _copy_config(self): + """Copy configuration to clipboard.""" + config_text = self.config_text.get(1.0, tk.END) + self.dialog.clipboard_clear() + self.dialog.clipboard_append(config_text) + messagebox.showinfo("Copied", "Configuration copied to clipboard!") + + def _save_claude_config(self): + """Save Claude configuration to file.""" + filename = filedialog.asksaveasfilename( + title="Save Claude Config", + defaultextension=".json", + filetypes=[("JSON files", "*.json"), ("All files", "*.*")], + initialvalue="claude_desktop_config.json" + ) + + if filename: + try: + config = self.config.generate_claude_config() + with open(filename, 'w') as f: + json.dump(config, f, indent=2) + messagebox.showinfo("Saved", f"Configuration saved to {filename}") + except Exception as e: + messagebox.showerror("Error", f"Failed to save file: {e}") + + def _test_server(self): + """Test the MCP server.""" + try: + # Try to run the server with --help to test if it's working + result = subprocess.run([ + sys.executable, "mcp_server.py", "--help" + ], capture_output=True, text=True, timeout=10) + + if result.returncode == 0: + messagebox.showinfo("Test Successful", + "MCP server appears to be working correctly!") + else: + messagebox.showerror("Test Failed", + f"Server test failed:\\n{result.stderr}") + except Exception as e: + messagebox.showerror("Test Error", f"Failed to test server: {e}") + + +def open_mcp_config(parent: tk.Tk): + """Open the MCP configuration dialog.""" + MCPConfigGUI(parent) + + +if __name__ == "__main__": + # Test the configuration + root = tk.Tk() + root.withdraw() # Hide main window + + app = MCPConfigGUI(root) + root.mainloop() \ No newline at end of file diff --git a/mcp_server.py b/mcp_server.py new file mode 100644 index 0000000..e9bbaa1 --- /dev/null +++ b/mcp_server.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 +""" +Multi-Model MCP Server for LLM_Train + +This server provides MCP (Model Context Protocol) access to multiple local models +managed by the LLM_Train application. It supports model discovery, switching, +and inference through a standardized interface. +""" + +import asyncio +import json +import logging +import os +import sys +import platform +import subprocess +from typing import Any, Dict, List, Optional, Sequence +import argparse +from pathlib import Path + +try: + from mcp.server import Server + from mcp.server.models import InitializationOptions + from mcp.server.stdio import stdio_server + from mcp.types import ( + CallToolRequestParams, + GetPromptRequestParams, + ListPromptsRequestParams, + ListToolsRequestParams, + Prompt, + PromptMessage, + Resource, + TextContent, + Tool, + EmbeddedResource, + ) +except ImportError: + print("MCP library not found. Install with: pip install mcp", file=sys.stderr) + sys.exit(1) + +# Import our local modules +try: + from model_library import ModelLibrary, ModelInfo + from settings_manager import SettingsManager + from llama_cpp import Llama +except ImportError as e: + print(f"Required modules not found: {e}", file=sys.stderr) + sys.exit(1) + + +class MultiModelMCPServer: + """MCP Server for managing multiple local models with CUDA support.""" + + def __init__(self, settings_path: str = "settings.json"): + self.settings = SettingsManager(settings_path) + self.library = None + self.current_model = None + self.current_llm = None + self.model_cache = {} # Cache for loaded models + + # Initialize model library if configured + library_root = self.settings.get('library.root_folder', '') + if library_root and os.path.exists(library_root): + max_depth = self.settings.get('library.max_depth', 3) + self.library = ModelLibrary(library_root, max_depth) + self.library._load_index() + + # Setup logging + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + # Detect system capabilities + self.system_info = self._detect_system_capabilities() + self.logger.info(f"System capabilities: {self.system_info}") + + def _detect_system_capabilities(self) -> Dict[str, Any]: + """Detect system capabilities including CUDA, ROCm, and Metal support.""" + capabilities = { + "platform": platform.system(), + "architecture": platform.machine(), + "cuda_available": False, + "cuda_version": None, + "cuda_devices": 0, + "rocm_available": False, + "metal_available": False, + "intel_gpu_available": False, + "recommended_layers": 0 + } + + try: + # Check for CUDA (NVIDIA) + if capabilities["platform"] in ["Windows", "Linux"]: + try: + # Try nvidia-smi command + result = subprocess.run( + ["nvidia-smi", "--query-gpu=count,driver_version", "--format=csv,noheader,nounits"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + lines = result.stdout.strip().split('\n') + if lines and lines[0]: + parts = lines[0].split(', ') + if len(parts) >= 2: + capabilities["cuda_devices"] = len(lines) + capabilities["cuda_version"] = parts[1] + capabilities["cuda_available"] = True + + # Recommend using most GPU layers for CUDA + capabilities["recommended_layers"] = 35 # Good default for most models + self.logger.info(f"CUDA detected: {capabilities['cuda_devices']} device(s), driver {capabilities['cuda_version']}") + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): + pass + + # Check for Intel GPU on Windows (Arc/Iris Xe) + if capabilities["platform"] == "Windows": + try: + result = subprocess.run( + ["wmic", "path", "win32_VideoController", "get", "name"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0 and "intel" in result.stdout.lower(): + capabilities["intel_gpu_available"] = True + if not capabilities["cuda_available"]: + capabilities["recommended_layers"] = 15 # Conservative for Intel GPU + self.logger.info("Intel GPU detected") + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): + pass + + # Check for ROCm (AMD) on Linux + elif capabilities["platform"] == "Linux": + try: + result = subprocess.run( + ["rocm-smi", "--showproductname"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + capabilities["rocm_available"] = True + capabilities["recommended_layers"] = 25 # Good default for ROCm + self.logger.info("ROCm (AMD GPU) detected") + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): + pass + + # Check for Metal (Apple Silicon) on macOS + elif capabilities["platform"] == "Darwin": + # Check if running on Apple Silicon + if "arm" in capabilities["architecture"].lower() or "apple" in platform.processor().lower(): + capabilities["metal_available"] = True + capabilities["recommended_layers"] = 30 # Good default for Apple Silicon + self.logger.info("Apple Silicon (Metal) detected") + + except Exception as e: + self.logger.warning(f"Error detecting system capabilities: {e}") + + return capabilities + + def get_available_models(self) -> List[ModelInfo]: + """Get list of available models.""" + if self.library: + return list(self.library.models.values()) + return [] + + def load_model(self, model_path: str, **kwargs) -> bool: + """Load a model for inference with optimized GPU acceleration.""" + try: + # Check if model is already loaded + if model_path in self.model_cache: + self.current_model = model_path + self.current_llm = self.model_cache[model_path] + return True + + # Get parameters with smart defaults based on system capabilities + n_ctx = kwargs.get('n_ctx', 4096) + n_threads = kwargs.get('n_threads', min(os.cpu_count() or 4, 8)) # Cap threads for stability + + # Smart GPU layer detection + n_gpu_layers = kwargs.get('n_gpu_layers') + if n_gpu_layers is None: + # Auto-detect optimal GPU layers + if self.system_info["cuda_available"]: + n_gpu_layers = self.system_info["recommended_layers"] + elif self.system_info["rocm_available"]: + n_gpu_layers = self.system_info["recommended_layers"] + elif self.system_info["metal_available"]: + n_gpu_layers = self.system_info["recommended_layers"] + elif self.system_info["intel_gpu_available"]: + n_gpu_layers = self.system_info["recommended_layers"] + else: + n_gpu_layers = 0 # CPU only + + # Additional optimizations based on platform + llm_kwargs = { + "model_path": model_path, + "n_ctx": n_ctx, + "n_gpu_layers": n_gpu_layers, + "n_threads": n_threads, + "verbose": False + } + + # Platform-specific optimizations + if self.system_info["platform"] == "Windows": + # Windows optimizations + if self.system_info["cuda_available"]: + llm_kwargs["n_batch"] = 512 # Good batch size for CUDA on Windows + elif self.system_info["intel_gpu_available"]: + llm_kwargs["n_batch"] = 256 # Conservative for Intel GPU + + elif self.system_info["platform"] == "Linux": + # Linux optimizations + if self.system_info["cuda_available"]: + llm_kwargs["n_batch"] = 512 + llm_kwargs["use_mmap"] = True # Better memory management on Linux + elif self.system_info["rocm_available"]: + llm_kwargs["n_batch"] = 256 # Conservative for ROCm + + elif self.system_info["platform"] == "Darwin": + # macOS optimizations + if self.system_info["metal_available"]: + llm_kwargs["n_batch"] = 512 + llm_kwargs["use_mmap"] = True + + self.logger.info(f"Loading model with: {n_gpu_layers} GPU layers, {n_threads} threads") + + llm = Llama(**llm_kwargs) + + # Cache the model (limit cache size) + if len(self.model_cache) >= 3: # Max 3 models in cache + # Remove oldest model + oldest_key = next(iter(self.model_cache)) + del self.model_cache[oldest_key] + + self.model_cache[model_path] = llm + self.current_model = model_path + self.current_llm = llm + + self.logger.info(f"Successfully loaded model: {model_path}") + if n_gpu_layers > 0: + acceleration = "CUDA" if self.system_info["cuda_available"] else \ + "ROCm" if self.system_info["rocm_available"] else \ + "Metal" if self.system_info["metal_available"] else \ + "Intel GPU" if self.system_info["intel_gpu_available"] else "GPU" + self.logger.info(f"Using {acceleration} acceleration with {n_gpu_layers} layers") + + return True + + except Exception as e: + self.logger.error(f"Failed to load model {model_path}: {e}") + return False + + def generate_text(self, prompt: str, **kwargs) -> Dict[str, Any]: + """Generate text using the current model.""" + if not self.current_llm: + return {"error": "No model loaded"} + + try: + max_tokens = kwargs.get('max_tokens', 256) + temperature = kwargs.get('temperature', 0.7) + stream = kwargs.get('stream', False) + + if stream: + # For MCP, we'll collect the stream and return the full result + result = "" + for chunk in self.current_llm.create_completion( + prompt, + max_tokens=max_tokens, + temperature=temperature, + stream=True + ): + if 'choices' in chunk and chunk['choices']: + text = chunk['choices'][0].get('text', '') + result += text + + return { + "text": result, + "model": self.current_model, + "tokens": len(result.split()) + } + else: + response = self.current_llm.create_completion( + prompt, + max_tokens=max_tokens, + temperature=temperature + ) + + return { + "text": response['choices'][0]['text'], + "model": self.current_model, + "tokens": response['usage']['total_tokens'] + } + + except Exception as e: + self.logger.error(f"Generation failed: {e}") + return {"error": str(e)} + + def chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]: + """Generate chat completion using the current model.""" + if not self.current_llm: + return {"error": "No model loaded"} + + try: + max_tokens = kwargs.get('max_tokens', 256) + temperature = kwargs.get('temperature', 0.7) + + response = self.current_llm.create_chat_completion( + messages=messages, + max_tokens=max_tokens, + temperature=temperature + ) + + return { + "message": response['choices'][0]['message'], + "model": self.current_model, + "tokens": response['usage']['total_tokens'] + } + + except Exception as e: + self.logger.error(f"Chat completion failed: {e}") + return {"error": str(e)} + + +# Global server instance +mcp_server = MultiModelMCPServer() + +# Create MCP server +server = Server("llm-train-models") + + +@server.list_tools() +async def handle_list_tools() -> List[Tool]: + """List available tools.""" + return [ + Tool( + name="list_models", + description="List all available models in the library", + inputSchema={ + "type": "object", + "properties": {}, + "required": [] + } + ), + Tool( + name="load_model", + description="Load a specific model for inference", + inputSchema={ + "type": "object", + "properties": { + "model_path": { + "type": "string", + "description": "Path to the model file" + }, + "n_ctx": { + "type": "integer", + "description": "Context window size", + "default": 4096 + }, + "n_gpu_layers": { + "type": "integer", + "description": "Number of GPU layers", + "default": 0 + } + }, + "required": ["model_path"] + } + ), + Tool( + name="generate_text", + description="Generate text using the current model", + inputSchema={ + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The text prompt for generation" + }, + "max_tokens": { + "type": "integer", + "description": "Maximum tokens to generate", + "default": 256 + }, + "temperature": { + "type": "number", + "description": "Sampling temperature", + "default": 0.7 + } + }, + "required": ["prompt"] + } + ), + Tool( + name="chat_completion", + description="Generate chat completion using the current model", + inputSchema={ + "type": "object", + "properties": { + "messages": { + "type": "array", + "description": "Chat messages", + "items": { + "type": "object", + "properties": { + "role": {"type": "string"}, + "content": {"type": "string"} + }, + "required": ["role", "content"] + } + }, + "max_tokens": { + "type": "integer", + "description": "Maximum tokens to generate", + "default": 256 + }, + "temperature": { + "type": "number", + "description": "Sampling temperature", + "default": 0.7 + } + }, + "required": ["messages"] + } + ), + Tool( + name="get_current_model", + description="Get information about the currently loaded model", + inputSchema={ + "type": "object", + "properties": {}, + "required": [] + } + ), + Tool( + name="get_system_info", + description="Get system capabilities and GPU acceleration status", + inputSchema={ + "type": "object", + "properties": {}, + "required": [] + } + ) + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: + """Handle tool calls.""" + + if name == "list_models": + models = mcp_server.get_available_models() + model_list = [] + + for model in models: + model_info = { + "name": model.name, + "path": model.path, + "type": model.file_type, + "size_mb": round(model.size_mb, 1), + "modified": model.modified_date, + "tags": model.tags, + "metadata": model.metadata + } + model_list.append(model_info) + + return [TextContent( + type="text", + text=json.dumps(model_list, indent=2) + )] + + elif name == "load_model": + model_path = arguments.get("model_path") + n_ctx = arguments.get("n_ctx", 4096) + n_gpu_layers = arguments.get("n_gpu_layers", 0) + + if not model_path: + return [TextContent(type="text", text="Error: model_path is required")] + + success = mcp_server.load_model( + model_path, + n_ctx=n_ctx, + n_gpu_layers=n_gpu_layers + ) + + if success: + return [TextContent( + type="text", + text=f"Successfully loaded model: {model_path}" + )] + else: + return [TextContent( + type="text", + text=f"Failed to load model: {model_path}" + )] + + elif name == "generate_text": + prompt = arguments.get("prompt") + max_tokens = arguments.get("max_tokens", 256) + temperature = arguments.get("temperature", 0.7) + + if not prompt: + return [TextContent(type="text", text="Error: prompt is required")] + + result = mcp_server.generate_text( + prompt, + max_tokens=max_tokens, + temperature=temperature + ) + + return [TextContent( + type="text", + text=json.dumps(result, indent=2) + )] + + elif name == "chat_completion": + messages = arguments.get("messages", []) + max_tokens = arguments.get("max_tokens", 256) + temperature = arguments.get("temperature", 0.7) + + if not messages: + return [TextContent(type="text", text="Error: messages are required")] + + result = mcp_server.chat_completion( + messages, + max_tokens=max_tokens, + temperature=temperature + ) + + return [TextContent( + type="text", + text=json.dumps(result, indent=2) + )] + + elif name == "get_current_model": + if mcp_server.current_model: + # Find model info + models = mcp_server.get_available_models() + current_info = None + + for model in models: + if model.path == mcp_server.current_model: + current_info = { + "name": model.name, + "path": model.path, + "type": model.file_type, + "size_mb": round(model.size_mb, 1), + "metadata": model.metadata + } + break + + if current_info: + return [TextContent( + type="text", + text=json.dumps(current_info, indent=2) + )] + + return [TextContent(type="text", text="No model currently loaded")] + + elif name == "get_system_info": + system_info = { + "platform": mcp_server.system_info["platform"], + "architecture": mcp_server.system_info["architecture"], + "acceleration": { + "cuda_available": mcp_server.system_info["cuda_available"], + "cuda_version": mcp_server.system_info["cuda_version"], + "cuda_devices": mcp_server.system_info["cuda_devices"], + "rocm_available": mcp_server.system_info["rocm_available"], + "metal_available": mcp_server.system_info["metal_available"], + "intel_gpu_available": mcp_server.system_info["intel_gpu_available"], + "recommended_layers": mcp_server.system_info["recommended_layers"] + }, + "current_model_acceleration": "Unknown" + } + + # Add current model acceleration info + if mcp_server.current_llm: + if mcp_server.system_info["cuda_available"]: + system_info["current_model_acceleration"] = "CUDA (NVIDIA)" + elif mcp_server.system_info["rocm_available"]: + system_info["current_model_acceleration"] = "ROCm (AMD)" + elif mcp_server.system_info["metal_available"]: + system_info["current_model_acceleration"] = "Metal (Apple)" + elif mcp_server.system_info["intel_gpu_available"]: + system_info["current_model_acceleration"] = "Intel GPU" + else: + system_info["current_model_acceleration"] = "CPU Only" + + return [TextContent( + type="text", + text=json.dumps(system_info, indent=2) + )] + + else: + return [TextContent(type="text", text=f"Unknown tool: {name}")] + + +@server.list_prompts() +async def handle_list_prompts() -> List[Prompt]: + """List available prompts.""" + return [ + Prompt( + name="model_comparison", + description="Compare multiple models on the same prompt", + arguments=[ + { + "name": "prompt", + "description": "The prompt to test across models", + "required": True + }, + { + "name": "models", + "description": "List of model paths to compare", + "required": True + } + ] + ), + Prompt( + name="model_benchmark", + description="Benchmark a model with standard prompts", + arguments=[ + { + "name": "model_path", + "description": "Path to the model to benchmark", + "required": True + } + ] + ) + ] + + +@server.get_prompt() +async def handle_get_prompt(name: str, arguments: Dict[str, str]) -> Prompt: + """Handle prompt requests.""" + + if name == "model_comparison": + prompt_text = arguments.get("prompt", "") + models = arguments.get("models", "").split(",") + + messages = [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text=f"Compare the following models on this prompt: '{prompt_text}'\n\n" + f"Models to test: {', '.join(models)}\n\n" + f"For each model, load it and generate a response, then provide a comparison." + ) + ) + ] + + return Prompt( + name=name, + description="Compare multiple models", + messages=messages + ) + + elif name == "model_benchmark": + model_path = arguments.get("model_path", "") + + benchmark_prompts = [ + "Explain quantum computing in simple terms.", + "Write a Python function to calculate fibonacci numbers.", + "What are the main causes of climate change?", + "Describe the process of photosynthesis.", + "Write a short story about a robot learning to paint." + ] + + messages = [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text=f"Benchmark the model at: {model_path}\n\n" + f"Test it with these prompts:\n" + + "\n".join(f"{i+1}. {p}" for i, p in enumerate(benchmark_prompts)) + + "\n\nProvide the model's response to each prompt and evaluate quality." + ) + ) + ] + + return Prompt( + name=name, + description="Benchmark model performance", + messages=messages + ) + + else: + raise ValueError(f"Unknown prompt: {name}") + + +async def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Multi-Model MCP Server") + parser.add_argument( + "--settings", + default="settings.json", + help="Path to settings file" + ) + + args = parser.parse_args() + + # Initialize server with settings + global mcp_server + mcp_server = MultiModelMCPServer(args.settings) + + # Run the server + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="llm-train-models", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=None, + experimental_capabilities=None + ) + ) + ) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/mcp_tab.py b/mcp_tab.py new file mode 100644 index 0000000..2b024a3 --- /dev/null +++ b/mcp_tab.py @@ -0,0 +1,554 @@ +#!/usr/bin/env python3 +""" +MCP Server Management Tab for DarkHal 2.0 + +Provides a comprehensive interface for managing and monitoring the MCP server, +including status monitoring, control, and configuration. +""" + +import tkinter as tk +from tkinter import ttk, messagebox, scrolledtext +import subprocess +import threading +import json +import os +import sys +import time +import queue +from pathlib import Path +from datetime import datetime +from typing import Optional, Dict, Any, List + + +class MCPServerManager: + """Manages the MCP server process and communication.""" + + def __init__(self): + self.process = None + self.status = "stopped" + self.start_time = None + self.message_queue = queue.Queue() + self.callbacks = { + 'on_start': [], + 'on_stop': [], + 'on_error': [], + 'on_message': [] + } + + def register_callback(self, event: str, callback): + """Register a callback for server events.""" + if event in self.callbacks: + self.callbacks[event].append(callback) + + def _trigger_callbacks(self, event: str, *args, **kwargs): + """Trigger all callbacks for an event.""" + for callback in self.callbacks.get(event, []): + try: + callback(*args, **kwargs) + except Exception as e: + print(f"Callback error: {e}") + + def start_server(self, config: Dict[str, Any] = None) -> bool: + """Start the MCP server.""" + if self.process and self.process.poll() is None: + return False # Already running + + try: + # Prepare environment + env = os.environ.copy() + if config: + env['MCP_CONFIG'] = json.dumps(config) + + # Start server process + server_path = Path(__file__).parent / "mcp_server.py" + self.process = subprocess.Popen( + [sys.executable, str(server_path)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + env=env + ) + + self.status = "running" + self.start_time = time.time() + + # Start output monitoring threads + threading.Thread(target=self._monitor_stdout, daemon=True).start() + threading.Thread(target=self._monitor_stderr, daemon=True).start() + + self._trigger_callbacks('on_start') + return True + + except Exception as e: + self.status = "error" + self._trigger_callbacks('on_error', str(e)) + return False + + def stop_server(self) -> bool: + """Stop the MCP server.""" + if not self.process: + return False + + try: + # Send shutdown signal + if self.process.poll() is None: + self.process.terminate() + + # Wait for graceful shutdown + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + + self.process = None + self.status = "stopped" + self.start_time = None + self._trigger_callbacks('on_stop') + return True + + except Exception as e: + self._trigger_callbacks('on_error', str(e)) + return False + + def restart_server(self, config: Dict[str, Any] = None) -> bool: + """Restart the MCP server.""" + self.stop_server() + time.sleep(1) # Brief pause + return self.start_server(config) + + def send_command(self, command: Dict[str, Any]) -> bool: + """Send a command to the server.""" + if not self.process or self.process.poll() is not None: + return False + + try: + command_json = json.dumps(command) + '\n' + self.process.stdin.write(command_json) + self.process.stdin.flush() + return True + except Exception: + return False + + def _monitor_stdout(self): + """Monitor stdout from the server.""" + while self.process and self.process.poll() is None: + try: + line = self.process.stdout.readline() + if line: + self._trigger_callbacks('on_message', 'stdout', line.strip()) + except Exception: + break + + def _monitor_stderr(self): + """Monitor stderr from the server.""" + while self.process and self.process.poll() is None: + try: + line = self.process.stderr.readline() + if line: + self._trigger_callbacks('on_message', 'stderr', line.strip()) + except Exception: + break + + def get_status(self) -> Dict[str, Any]: + """Get server status information.""" + return { + 'status': self.status, + 'running': self.process and self.process.poll() is None, + 'pid': self.process.pid if self.process else None, + 'uptime': time.time() - self.start_time if self.start_time else 0 + } + + +class MCPTab: + """MCP Server management tab for DarkHal 2.0.""" + + def __init__(self, parent: ttk.Frame, settings_manager): + self.parent = parent + self.settings = settings_manager + self.server_manager = MCPServerManager() + self.tools_info = [] + self.server_config = {} + + # Register callbacks + self.server_manager.register_callback('on_start', self._on_server_start) + self.server_manager.register_callback('on_stop', self._on_server_stop) + self.server_manager.register_callback('on_error', self._on_server_error) + self.server_manager.register_callback('on_message', self._on_server_message) + + self._build_ui() + self._load_config() + self._update_status() + + def _build_ui(self): + """Build the MCP tab UI.""" + # Main container + main_frame = ttk.Frame(self.parent) + main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Left panel - Control and Status + left_panel = ttk.Frame(main_frame) + left_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # Server Status Frame + status_frame = ttk.LabelFrame(left_panel, text="Server Status", padding=10) + status_frame.pack(fill=tk.X, pady=(0, 10)) + + # Status indicators + status_grid = ttk.Frame(status_frame) + status_grid.pack(fill=tk.X) + + # Status light + self.status_canvas = tk.Canvas(status_grid, width=20, height=20) + self.status_canvas.grid(row=0, column=0, padx=(0, 10)) + self.status_indicator = self.status_canvas.create_oval(2, 2, 18, 18, fill="red") + + ttk.Label(status_grid, text="Status:").grid(row=0, column=1, sticky=tk.W) + self.status_label = ttk.Label(status_grid, text="Stopped", font=("TkDefaultFont", 10, "bold")) + self.status_label.grid(row=0, column=2, sticky=tk.W, padx=(5, 20)) + + ttk.Label(status_grid, text="PID:").grid(row=0, column=3, sticky=tk.W) + self.pid_label = ttk.Label(status_grid, text="N/A") + self.pid_label.grid(row=0, column=4, sticky=tk.W, padx=(5, 20)) + + ttk.Label(status_grid, text="Uptime:").grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + self.uptime_label = ttk.Label(status_grid, text="00:00:00") + self.uptime_label.grid(row=1, column=2, sticky=tk.W, padx=(5, 20), pady=(5, 0)) + + ttk.Label(status_grid, text="Port:").grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + self.port_label = ttk.Label(status_grid, text="N/A") + self.port_label.grid(row=1, column=4, sticky=tk.W, padx=(5, 20), pady=(5, 0)) + + # Control Buttons Frame + control_frame = ttk.LabelFrame(left_panel, text="Server Control", padding=10) + control_frame.pack(fill=tk.X, pady=(0, 10)) + + button_frame = ttk.Frame(control_frame) + button_frame.pack(fill=tk.X) + + self.start_btn = ttk.Button(button_frame, text="Start Server", command=self._start_server) + self.start_btn.pack(side=tk.LEFT, padx=2) + + self.stop_btn = ttk.Button(button_frame, text="Stop Server", command=self._stop_server, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT, padx=2) + + self.restart_btn = ttk.Button(button_frame, text="Restart Server", command=self._restart_server, state=tk.DISABLED) + self.restart_btn.pack(side=tk.LEFT, padx=2) + + ttk.Button(button_frame, text="Configure", command=self._open_config).pack(side=tk.LEFT, padx=(20, 2)) + ttk.Button(button_frame, text="Test Connection", command=self._test_connection).pack(side=tk.LEFT, padx=2) + + # Configuration Frame + config_frame = ttk.LabelFrame(left_panel, text="Configuration", padding=10) + config_frame.pack(fill=tk.X, pady=(0, 10)) + + config_grid = ttk.Frame(config_frame) + config_grid.pack(fill=tk.X) + + ttk.Label(config_grid, text="Model Cache:").grid(row=0, column=0, sticky=tk.W, pady=2) + self.cache_var = tk.StringVar(value="3 models") + ttk.Label(config_grid, textvariable=self.cache_var).grid(row=0, column=1, sticky=tk.W, padx=(10, 0)) + + ttk.Label(config_grid, text="Default Context:").grid(row=1, column=0, sticky=tk.W, pady=2) + self.ctx_var = tk.StringVar(value="4096 tokens") + ttk.Label(config_grid, textvariable=self.ctx_var).grid(row=1, column=1, sticky=tk.W, padx=(10, 0)) + + ttk.Label(config_grid, text="GPU Layers:").grid(row=2, column=0, sticky=tk.W, pady=2) + self.gpu_var = tk.StringVar(value="Auto") + ttk.Label(config_grid, textvariable=self.gpu_var).grid(row=2, column=1, sticky=tk.W, padx=(10, 0)) + + ttk.Label(config_grid, text="Log Level:").grid(row=3, column=0, sticky=tk.W, pady=2) + self.log_var = tk.StringVar(value="INFO") + ttk.Label(config_grid, textvariable=self.log_var).grid(row=3, column=1, sticky=tk.W, padx=(10, 0)) + + # Tools Frame + tools_frame = ttk.LabelFrame(left_panel, text="Available Tools", padding=10) + tools_frame.pack(fill=tk.BOTH, expand=True) + + # Tools list + columns = ("tool", "description", "status") + self.tools_tree = ttk.Treeview(tools_frame, columns=columns, show="headings", height=8) + + self.tools_tree.heading("tool", text="Tool") + self.tools_tree.heading("description", text="Description") + self.tools_tree.heading("status", text="Status") + + self.tools_tree.column("tool", width=150) + self.tools_tree.column("description", width=300) + self.tools_tree.column("status", width=80) + + tools_scroll = ttk.Scrollbar(tools_frame, orient="vertical", command=self.tools_tree.yview) + self.tools_tree.configure(yscrollcommand=tools_scroll.set) + + self.tools_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + tools_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Populate default tools + self._populate_tools() + + # Right panel - Logs + right_panel = ttk.Frame(main_frame) + right_panel.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=(10, 0)) + + # Server Logs Frame + log_frame = ttk.LabelFrame(right_panel, text="Server Logs", padding=10) + log_frame.pack(fill=tk.BOTH, expand=True) + + # Log display + self.log_text = scrolledtext.ScrolledText(log_frame, height=20, wrap=tk.WORD, + bg="#1a1a1a", fg="#00ff88", + font=("Consolas", 9)) + self.log_text.pack(fill=tk.BOTH, expand=True) + + # Configure log tags + self.log_text.tag_configure("error", foreground="#ff4444") + self.log_text.tag_configure("warning", foreground="#ffaa00") + self.log_text.tag_configure("info", foreground="#00aaff") + self.log_text.tag_configure("success", foreground="#00ff88") + + # Log controls + log_controls = ttk.Frame(log_frame) + log_controls.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button(log_controls, text="Clear Logs", command=self._clear_logs).pack(side=tk.LEFT) + ttk.Button(log_controls, text="Save Logs", command=self._save_logs).pack(side=tk.LEFT, padx=5) + + self.autoscroll_var = tk.BooleanVar(value=True) + ttk.Checkbutton(log_controls, text="Auto-scroll", variable=self.autoscroll_var).pack(side=tk.LEFT, padx=10) + + # Performance Metrics Frame + metrics_frame = ttk.LabelFrame(right_panel, text="Performance Metrics", padding=10) + metrics_frame.pack(fill=tk.X, pady=(10, 0)) + + metrics_grid = ttk.Frame(metrics_frame) + metrics_grid.pack(fill=tk.X) + + ttk.Label(metrics_grid, text="Requests/sec:").grid(row=0, column=0, sticky=tk.W, pady=2) + self.req_rate_label = ttk.Label(metrics_grid, text="0") + self.req_rate_label.grid(row=0, column=1, sticky=tk.W, padx=(10, 30)) + + ttk.Label(metrics_grid, text="Avg Response:").grid(row=0, column=2, sticky=tk.W, pady=2) + self.response_time_label = ttk.Label(metrics_grid, text="0ms") + self.response_time_label.grid(row=0, column=3, sticky=tk.W, padx=(10, 0)) + + ttk.Label(metrics_grid, text="Total Requests:").grid(row=1, column=0, sticky=tk.W, pady=2) + self.total_requests_label = ttk.Label(metrics_grid, text="0") + self.total_requests_label.grid(row=1, column=1, sticky=tk.W, padx=(10, 30)) + + ttk.Label(metrics_grid, text="Errors:").grid(row=1, column=2, sticky=tk.W, pady=2) + self.errors_label = ttk.Label(metrics_grid, text="0") + self.errors_label.grid(row=1, column=3, sticky=tk.W, padx=(10, 0)) + + # Start status update timer + self.parent.after(1000, self._update_status) + + def _populate_tools(self): + """Populate the tools list with available MCP tools.""" + tools = [ + ("list_models", "List all available models in the library", "Ready"), + ("load_model", "Load a model with specified parameters", "Ready"), + ("unload_model", "Unload the current model from memory", "Ready"), + ("generate_text", "Generate text using the loaded model", "Ready"), + ("chat", "Interactive chat with context management", "Ready"), + ("get_system_info", "Get system and GPU information", "Ready"), + ("get_model_info", "Get information about loaded model", "Ready"), + ("list_loras", "List available LoRA adapters", "Ready"), + ("apply_lora", "Apply a LoRA adapter to the model", "Ready"), + ("benchmark", "Run performance benchmarks", "Ready") + ] + + for tool in tools: + self.tools_tree.insert("", tk.END, values=tool) + + def _load_config(self): + """Load MCP configuration.""" + try: + config_file = Path("mcp_config.json") + if config_file.exists(): + with open(config_file, 'r') as f: + self.server_config = json.load(f) + + # Update display + cache = self.server_config.get('models', {}).get('cache_size', 3) + self.cache_var.set(f"{cache} models") + + ctx = self.server_config.get('models', {}).get('default_context', 4096) + self.ctx_var.set(f"{ctx} tokens") + + gpu = self.server_config.get('models', {}).get('default_gpu_layers', 0) + self.gpu_var.set("Auto" if gpu == 0 else f"{gpu} layers") + + log_level = self.server_config.get('logging', {}).get('level', 'INFO') + self.log_var.set(log_level) + + except Exception as e: + self._log(f"Error loading config: {e}", "error") + + def _start_server(self): + """Start the MCP server.""" + self._log("Starting MCP server...", "info") + + if self.server_manager.start_server(self.server_config): + self._log("MCP server started successfully", "success") + else: + self._log("Failed to start MCP server", "error") + + def _stop_server(self): + """Stop the MCP server.""" + self._log("Stopping MCP server...", "info") + + if self.server_manager.stop_server(): + self._log("MCP server stopped", "info") + else: + self._log("Failed to stop MCP server", "error") + + def _restart_server(self): + """Restart the MCP server.""" + self._log("Restarting MCP server...", "info") + + if self.server_manager.restart_server(self.server_config): + self._log("MCP server restarted successfully", "success") + else: + self._log("Failed to restart MCP server", "error") + + def _test_connection(self): + """Test the server connection.""" + if not self.server_manager.get_status()['running']: + messagebox.showinfo("Not Running", "Server is not running. Start it first.") + return + + # Send a test command + test_cmd = { + "jsonrpc": "2.0", + "method": "tools/list", + "id": 1 + } + + if self.server_manager.send_command(test_cmd): + self._log("Connection test sent", "info") + else: + self._log("Connection test failed", "error") + + def _open_config(self): + """Open the configuration dialog.""" + from mcp_config import open_mcp_config + open_mcp_config(self.parent.winfo_toplevel()) + self._load_config() # Reload after config changes + + def _clear_logs(self): + """Clear the log display.""" + self.log_text.delete(1.0, tk.END) + + def _save_logs(self): + """Save logs to file.""" + from tkinter import filedialog + + filename = filedialog.asksaveasfilename( + title="Save Logs", + defaultextension=".log", + filetypes=[("Log files", "*.log"), ("Text files", "*.txt"), ("All files", "*.*")] + ) + + if filename: + try: + with open(filename, 'w') as f: + f.write(self.log_text.get(1.0, tk.END)) + self._log(f"Logs saved to {filename}", "success") + except Exception as e: + self._log(f"Error saving logs: {e}", "error") + + def _log(self, message: str, level: str = "info"): + """Add a message to the log display.""" + timestamp = datetime.now().strftime("%H:%M:%S") + log_entry = f"[{timestamp}] {message}\n" + + # Determine tag based on level + tag = level.lower() + if tag not in ["error", "warning", "info", "success"]: + tag = "info" + + # Insert with tag + self.log_text.insert(tk.END, log_entry, tag) + + # Auto-scroll if enabled + if self.autoscroll_var.get(): + self.log_text.see(tk.END) + + def _on_server_start(self): + """Handle server start event.""" + self.start_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.restart_btn.config(state=tk.NORMAL) + self.status_canvas.itemconfig(self.status_indicator, fill="#00ff88") + self.status_label.config(text="Running") + + # Update tool status + for item in self.tools_tree.get_children(): + self.tools_tree.set(item, "status", "Active") + + def _on_server_stop(self): + """Handle server stop event.""" + self.start_btn.config(state=tk.NORMAL) + self.stop_btn.config(state=tk.DISABLED) + self.restart_btn.config(state=tk.DISABLED) + self.status_canvas.itemconfig(self.status_indicator, fill="red") + self.status_label.config(text="Stopped") + self.pid_label.config(text="N/A") + self.uptime_label.config(text="00:00:00") + + # Update tool status + for item in self.tools_tree.get_children(): + self.tools_tree.set(item, "status", "Ready") + + def _on_server_error(self, error: str): + """Handle server error event.""" + self._log(f"Server error: {error}", "error") + self.status_canvas.itemconfig(self.status_indicator, fill="#ff4444") + self.status_label.config(text="Error") + + def _on_server_message(self, stream: str, message: str): + """Handle server message event.""" + if stream == "stderr": + if "error" in message.lower(): + self._log(message, "error") + elif "warning" in message.lower(): + self._log(message, "warning") + else: + self._log(message, "info") + else: + # Try to parse as JSON for structured logs + try: + data = json.loads(message) + if "level" in data: + self._log(data.get("message", message), data["level"].lower()) + else: + self._log(message, "info") + except: + self._log(message, "info") + + def _update_status(self): + """Update status display periodically.""" + status = self.server_manager.get_status() + + if status['running']: + # Update PID + self.pid_label.config(text=str(status['pid'])) + + # Update uptime + uptime = int(status['uptime']) + hours = uptime // 3600 + minutes = (uptime % 3600) // 60 + seconds = uptime % 60 + self.uptime_label.config(text=f"{hours:02d}:{minutes:02d}:{seconds:02d}") + + # Update port (from config) + port = self.server_config.get('server', {}).get('port', 'stdio') + self.port_label.config(text=str(port)) + else: + self.pid_label.config(text="N/A") + self.port_label.config(text="N/A") + + # Schedule next update + self.parent.after(1000, self._update_status) \ No newline at end of file diff --git a/model_converter.py b/model_converter.py new file mode 100644 index 0000000..2a4d40a --- /dev/null +++ b/model_converter.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +""" +Model Conversion and Editing Tools for DarkHal 2.0 + +Provides comprehensive model conversion between formats, quantization options, +and model editing capabilities. +""" + +import tkinter as tk +from tkinter import ttk, messagebox, filedialog, scrolledtext +import os +import sys +import json +import subprocess +import threading +import queue +import shutil +from pathlib import Path +from typing import Optional, Dict, Any, List, Tuple +from datetime import datetime +import tempfile + + +class ModelConverter: + """Handles model conversion operations.""" + + SUPPORTED_FORMATS = { + 'gguf': 'GGUF (llama.cpp)', + 'safetensors': 'SafeTensors (HuggingFace)', + 'bin': 'PyTorch Binary', + 'pt': 'PyTorch', + 'pth': 'PyTorch State Dict', + 'onnx': 'ONNX', + 'tflite': 'TensorFlow Lite', + 'h5': 'Keras/HDF5' + } + + QUANTIZATION_TYPES = { + 'q4_0': 'Q4_0 - 4-bit (smallest, lower quality)', + 'q4_1': 'Q4_1 - 4-bit (small, better than q4_0)', + 'q4_k_m': 'Q4_K_M - 4-bit (medium, recommended)', + 'q4_k_s': 'Q4_K_S - 4-bit (small)', + 'q5_0': 'Q5_0 - 5-bit', + 'q5_1': 'Q5_1 - 5-bit (better than q5_0)', + 'q5_k_m': 'Q5_K_M - 5-bit (medium, recommended)', + 'q5_k_s': 'Q5_K_S - 5-bit (small)', + 'q6_k': 'Q6_K - 6-bit (good quality/size ratio)', + 'q8_0': 'Q8_0 - 8-bit (high quality)', + 'f16': 'FP16 - 16-bit float', + 'f32': 'FP32 - 32-bit float (original)' + } + + def __init__(self): + self.conversion_queue = queue.Queue() + self.current_process = None + + def get_model_info(self, model_path: str) -> Dict[str, Any]: + """Get information about a model file.""" + path = Path(model_path) + if not path.exists(): + return None + + info = { + 'name': path.stem, + 'path': str(path), + 'format': path.suffix.lower().lstrip('.'), + 'size': path.stat().st_size, + 'size_mb': path.stat().st_size / (1024 * 1024), + 'modified': datetime.fromtimestamp(path.stat().st_mtime) + } + + # Try to extract additional metadata + if info['format'] == 'gguf': + info.update(self._get_gguf_info(path)) + elif info['format'] in ['safetensors', 'bin']: + info.update(self._get_hf_info(path)) + + return info + + def _get_gguf_info(self, path: Path) -> Dict[str, Any]: + """Extract GGUF model information.""" + info = {} + try: + # Try to use llama-cpp-python if available + from llama_cpp import Llama + # This would require actually loading the model which is expensive + # For now, extract from filename + name = path.stem.lower() + + # Detect quantization + for q_type in self.QUANTIZATION_TYPES.keys(): + if q_type in name: + info['quantization'] = q_type + break + + # Detect model size + import re + size_match = re.search(r'(\d+)b', name, re.IGNORECASE) + if size_match: + info['parameters'] = f"{size_match.group(1)}B" + + except Exception: + pass + + return info + + def _get_hf_info(self, path: Path) -> Dict[str, Any]: + """Extract HuggingFace model information.""" + info = {} + try: + # Look for config.json in the same directory + config_path = path.parent / "config.json" + if config_path.exists(): + with open(config_path, 'r') as f: + config = json.load(f) + info['model_type'] = config.get('model_type', 'unknown') + info['architectures'] = config.get('architectures', []) + info['vocab_size'] = config.get('vocab_size', 0) + + # Calculate parameters if possible + if 'hidden_size' in config and 'num_hidden_layers' in config: + hidden = config['hidden_size'] + layers = config['num_hidden_layers'] + vocab = config.get('vocab_size', 0) + # Rough parameter estimation + params = (hidden * hidden * 4 * layers + vocab * hidden) / 1e9 + info['parameters'] = f"{params:.1f}B" + except Exception: + pass + + return info + + def convert_to_gguf(self, input_path: str, output_path: str, + quantization: str = 'q4_k_m', + progress_callback: Optional[callable] = None) -> bool: + """Convert a model to GGUF format.""" + try: + input_format = Path(input_path).suffix.lower().lstrip('.') + + if input_format == 'gguf': + # Already GGUF, just quantize if needed + return self.quantize_gguf(input_path, output_path, quantization, progress_callback) + + # For HuggingFace models, use convert.py from llama.cpp + if input_format in ['safetensors', 'bin']: + return self._convert_hf_to_gguf(input_path, output_path, quantization, progress_callback) + + # For other formats, try generic conversion + return self._generic_convert(input_path, output_path, 'gguf', progress_callback) + + except Exception as e: + if progress_callback: + progress_callback(f"Error: {e}", 100, "error") + return False + + def quantize_gguf(self, input_path: str, output_path: str, + quantization: str = 'q4_k_m', + progress_callback: Optional[callable] = None) -> bool: + """Quantize a GGUF model.""" + try: + # Look for quantize executable + quantize_exe = self._find_quantize_executable() + if not quantize_exe: + if progress_callback: + progress_callback("Quantize executable not found", 100, "error") + return False + + # Run quantization + cmd = [str(quantize_exe), input_path, output_path, quantization] + + if progress_callback: + progress_callback("Starting quantization...", 0, "info") + + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1 + ) + + self.current_process = process + + # Monitor progress + while True: + line = process.stderr.readline() + if not line: + break + + if progress_callback: + # Parse progress from output + if "%" in line: + try: + import re + match = re.search(r'(\d+)%', line) + if match: + percent = int(match.group(1)) + progress_callback(line.strip(), percent, "info") + except: + progress_callback(line.strip(), -1, "info") + else: + progress_callback(line.strip(), -1, "info") + + process.wait() + + if process.returncode == 0: + if progress_callback: + progress_callback("Quantization complete!", 100, "success") + return True + else: + if progress_callback: + progress_callback("Quantization failed", 100, "error") + return False + + except Exception as e: + if progress_callback: + progress_callback(f"Error: {e}", 100, "error") + return False + finally: + self.current_process = None + + def _find_quantize_executable(self) -> Optional[Path]: + """Find the quantize executable.""" + # Common locations + locations = [ + Path("llama-cpp-python/vendor/llama.cpp/quantize"), + Path("llama.cpp/quantize"), + Path("bin/quantize"), + Path("quantize"), + Path("quantize.exe") + ] + + for loc in locations: + if loc.exists(): + return loc + + # Check in PATH + import shutil + exe = shutil.which("quantize") + if exe: + return Path(exe) + + return None + + def _convert_hf_to_gguf(self, input_path: str, output_path: str, + quantization: str, + progress_callback: Optional[callable] = None) -> bool: + """Convert HuggingFace model to GGUF.""" + try: + # Look for convert.py script + convert_script = self._find_convert_script() + if not convert_script: + if progress_callback: + progress_callback("Convert script not found", 100, "error") + return False + + # First convert to FP16 GGUF + temp_gguf = output_path.replace('.gguf', '_fp16.gguf') + + cmd = [ + sys.executable, + str(convert_script), + str(Path(input_path).parent), # Model directory + "--outfile", temp_gguf, + "--outtype", "f16" + ] + + if progress_callback: + progress_callback("Converting to GGUF format...", 0, "info") + + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = process.communicate() + + if process.returncode != 0: + if progress_callback: + progress_callback(f"Conversion failed: {stderr}", 100, "error") + return False + + if progress_callback: + progress_callback("Conversion complete, quantizing...", 50, "info") + + # Then quantize if needed + if quantization != 'f16': + result = self.quantize_gguf(temp_gguf, output_path, quantization, progress_callback) + # Clean up temp file + try: + os.remove(temp_gguf) + except: + pass + return result + else: + # Just rename + shutil.move(temp_gguf, output_path) + if progress_callback: + progress_callback("Conversion complete!", 100, "success") + return True + + except Exception as e: + if progress_callback: + progress_callback(f"Error: {e}", 100, "error") + return False + + def _find_convert_script(self) -> Optional[Path]: + """Find the convert.py script.""" + locations = [ + Path("llama-cpp-python/vendor/llama.cpp/convert.py"), + Path("llama.cpp/convert.py"), + Path("scripts/convert.py"), + Path("convert.py") + ] + + for loc in locations: + if loc.exists(): + return loc + + return None + + def _generic_convert(self, input_path: str, output_path: str, + target_format: str, + progress_callback: Optional[callable] = None) -> bool: + """Generic conversion using available tools.""" + # This would use tools like ONNX converters, TensorFlow converters, etc. + # For now, return False as not implemented + if progress_callback: + progress_callback(f"Conversion to {target_format} not yet implemented", 100, "error") + return False + + def merge_lora(self, base_model: str, lora_path: str, output_path: str, + progress_callback: Optional[callable] = None) -> bool: + """Merge a LoRA adapter into a base model.""" + try: + # This would use a LoRA merging tool + # For now, simplified implementation + if progress_callback: + progress_callback("LoRA merging not yet implemented", 100, "error") + return False + + except Exception as e: + if progress_callback: + progress_callback(f"Error: {e}", 100, "error") + return False + + def split_model(self, input_path: str, output_dir: str, + num_shards: int = 2, + progress_callback: Optional[callable] = None) -> bool: + """Split a model into multiple shards.""" + try: + # This would split large models for easier distribution + if progress_callback: + progress_callback("Model splitting not yet implemented", 100, "error") + return False + + except Exception as e: + if progress_callback: + progress_callback(f"Error: {e}", 100, "error") + return False + + +class ModelConverterTab: + """Model conversion and editing tab for DarkHal 2.0.""" + + def __init__(self, parent: ttk.Frame, settings_manager): + self.parent = parent + self.settings = settings_manager + self.converter = ModelConverter() + self.current_model = None + + self._build_ui() + + def _build_ui(self): + """Build the converter tab UI.""" + # Create paned window for split view + paned = ttk.PanedWindow(self.parent, orient=tk.HORIZONTAL) + paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Left panel - Model Selection and Info + left_frame = ttk.Frame(paned) + paned.add(left_frame, weight=1) + + # Model Selection Frame + select_frame = ttk.LabelFrame(left_frame, text="Model Selection", padding=10) + select_frame.pack(fill=tk.X, pady=(0, 10)) + + # Input model + ttk.Label(select_frame, text="Input Model:").pack(anchor=tk.W) + + input_frame = ttk.Frame(select_frame) + input_frame.pack(fill=tk.X, pady=(5, 10)) + + self.input_path_var = tk.StringVar() + self.input_entry = ttk.Entry(input_frame, textvariable=self.input_path_var) + self.input_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + + ttk.Button(input_frame, text="Browse", command=self._browse_input).pack(side=tk.LEFT, padx=(5, 0)) + ttk.Button(input_frame, text="Analyze", command=self._analyze_model).pack(side=tk.LEFT, padx=(5, 0)) + + # Model Information Frame + info_frame = ttk.LabelFrame(left_frame, text="Model Information", padding=10) + info_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + self.info_text = scrolledtext.ScrolledText(info_frame, height=10, wrap=tk.WORD) + self.info_text.pack(fill=tk.BOTH, expand=True) + + # Conversion Options Frame + options_frame = ttk.LabelFrame(left_frame, text="Conversion Options", padding=10) + options_frame.pack(fill=tk.X) + + # Output format + ttk.Label(options_frame, text="Output Format:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.output_format_var = tk.StringVar(value="gguf") + format_combo = ttk.Combobox(options_frame, textvariable=self.output_format_var, + values=list(ModelConverter.SUPPORTED_FORMATS.keys()), + state="readonly", width=20) + format_combo.grid(row=0, column=1, sticky=tk.W, pady=5) + format_combo.bind('<>', self._on_format_change) + + # Quantization (for GGUF) + ttk.Label(options_frame, text="Quantization:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.quantization_var = tk.StringVar(value="q4_k_m") + self.quant_combo = ttk.Combobox(options_frame, textvariable=self.quantization_var, + values=list(ModelConverter.QUANTIZATION_TYPES.keys()), + state="readonly", width=20) + self.quant_combo.grid(row=1, column=1, sticky=tk.W, pady=5) + + # Output path + ttk.Label(options_frame, text="Output Path:").grid(row=2, column=0, sticky=tk.W, pady=5) + + output_frame = ttk.Frame(options_frame) + output_frame.grid(row=2, column=1, sticky=tk.W+tk.E, pady=5) + + self.output_path_var = tk.StringVar() + self.output_entry = ttk.Entry(output_frame, textvariable=self.output_path_var, width=30) + self.output_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + + ttk.Button(output_frame, text="Browse", command=self._browse_output).pack(side=tk.LEFT, padx=(5, 0)) + + # Conversion button + self.convert_btn = ttk.Button(options_frame, text="Start Conversion", + command=self._start_conversion) + self.convert_btn.grid(row=3, column=0, columnspan=2, pady=(10, 0)) + + # Right panel - Advanced Tools and Progress + right_frame = ttk.Frame(paned) + paned.add(right_frame, weight=1) + + # Advanced Tools Frame + tools_frame = ttk.LabelFrame(right_frame, text="Advanced Tools", padding=10) + tools_frame.pack(fill=tk.X, pady=(0, 10)) + + # Tool buttons + tool_grid = ttk.Frame(tools_frame) + tool_grid.pack(fill=tk.X) + + ttk.Button(tool_grid, text="Merge LoRA", command=self._open_lora_merger).grid(row=0, column=0, padx=2, pady=2) + ttk.Button(tool_grid, text="Split Model", command=self._open_model_splitter).grid(row=0, column=1, padx=2, pady=2) + ttk.Button(tool_grid, text="Optimize", command=self._open_optimizer).grid(row=0, column=2, padx=2, pady=2) + + ttk.Button(tool_grid, text="Batch Convert", command=self._open_batch_converter).grid(row=1, column=0, padx=2, pady=2) + ttk.Button(tool_grid, text="Compare Models", command=self._open_model_compare).grid(row=1, column=1, padx=2, pady=2) + ttk.Button(tool_grid, text="Edit Metadata", command=self._open_metadata_editor).grid(row=1, column=2, padx=2, pady=2) + + # Quantization Comparison + compare_frame = ttk.LabelFrame(right_frame, text="Quantization Comparison", padding=10) + compare_frame.pack(fill=tk.X, pady=(0, 10)) + + # Comparison table + columns = ("Type", "Size", "Quality", "Speed") + self.compare_tree = ttk.Treeview(compare_frame, columns=columns, show="headings", height=6) + + self.compare_tree.heading("Type", text="Type") + self.compare_tree.heading("Size", text="Size") + self.compare_tree.heading("Quality", text="Quality") + self.compare_tree.heading("Speed", text="Speed") + + self.compare_tree.column("Type", width=80) + self.compare_tree.column("Size", width=80) + self.compare_tree.column("Quality", width=80) + self.compare_tree.column("Speed", width=80) + + self.compare_tree.pack(fill=tk.X) + + # Populate comparison table + self._populate_comparison() + + # Progress Frame + progress_frame = ttk.LabelFrame(right_frame, text="Conversion Progress", padding=10) + progress_frame.pack(fill=tk.BOTH, expand=True) + + # Progress bar + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, + maximum=100, length=400) + self.progress_bar.pack(fill=tk.X, pady=(0, 10)) + + # Progress log + self.progress_text = scrolledtext.ScrolledText(progress_frame, height=10, wrap=tk.WORD, + bg="#1a1a1a", fg="#00ff88", + font=("Consolas", 9)) + self.progress_text.pack(fill=tk.BOTH, expand=True) + + # Control buttons + control_frame = ttk.Frame(progress_frame) + control_frame.pack(fill=tk.X, pady=(5, 0)) + + self.cancel_btn = ttk.Button(control_frame, text="Cancel", command=self._cancel_conversion, + state=tk.DISABLED) + self.cancel_btn.pack(side=tk.LEFT) + + ttk.Button(control_frame, text="Clear Log", command=self._clear_log).pack(side=tk.LEFT, padx=5) + + def _populate_comparison(self): + """Populate the quantization comparison table.""" + comparisons = [ + ("Q4_0", "~3.5GB", "★★☆☆☆", "★★★★★"), + ("Q4_K_M", "~3.8GB", "★★★☆☆", "★★★★★"), + ("Q5_K_M", "~4.5GB", "★★★★☆", "★★★★☆"), + ("Q6_K", "~5.5GB", "★★★★☆", "★★★☆☆"), + ("Q8_0", "~7GB", "★★★★★", "★★☆☆☆"), + ("FP16", "~14GB", "★★★★★", "★☆☆☆☆"), + ] + + for comp in comparisons: + self.compare_tree.insert("", tk.END, values=comp) + + def _browse_input(self): + """Browse for input model.""" + filename = filedialog.askopenfilename( + title="Select Model File", + filetypes=[ + ("All Models", "*.gguf;*.safetensors;*.bin;*.pt;*.pth;*.onnx"), + ("GGUF", "*.gguf"), + ("SafeTensors", "*.safetensors"), + ("PyTorch", "*.bin;*.pt;*.pth"), + ("ONNX", "*.onnx"), + ("All Files", "*.*") + ] + ) + + if filename: + self.input_path_var.set(filename) + self._analyze_model() + + # Auto-generate output path + input_path = Path(filename) + output_format = self.output_format_var.get() + quantization = self.quantization_var.get() + + output_name = f"{input_path.stem}_{quantization}.{output_format}" + output_path = input_path.parent / output_name + self.output_path_var.set(str(output_path)) + + def _browse_output(self): + """Browse for output location.""" + filename = filedialog.asksaveasfilename( + title="Save Converted Model As", + defaultextension=f".{self.output_format_var.get()}", + filetypes=[ + (f"{self.output_format_var.get().upper()}", f"*.{self.output_format_var.get()}"), + ("All Files", "*.*") + ] + ) + + if filename: + self.output_path_var.set(filename) + + def _analyze_model(self): + """Analyze the selected model.""" + model_path = self.input_path_var.get() + if not model_path or not os.path.exists(model_path): + messagebox.showerror("Error", "Please select a valid model file") + return + + self.info_text.delete(1.0, tk.END) + self.info_text.insert(tk.END, "Analyzing model...\n\n") + + # Get model info + info = self.converter.get_model_info(model_path) + if info: + self.current_model = info + + # Display info + self.info_text.insert(tk.END, f"Name: {info['name']}\n") + self.info_text.insert(tk.END, f"Format: {info['format'].upper()}\n") + self.info_text.insert(tk.END, f"Size: {info['size_mb']:.1f} MB\n") + self.info_text.insert(tk.END, f"Modified: {info['modified'].strftime('%Y-%m-%d %H:%M')}\n") + + if 'quantization' in info: + self.info_text.insert(tk.END, f"Quantization: {info['quantization']}\n") + if 'parameters' in info: + self.info_text.insert(tk.END, f"Parameters: {info['parameters']}\n") + if 'model_type' in info: + self.info_text.insert(tk.END, f"Model Type: {info['model_type']}\n") + if 'architectures' in info: + self.info_text.insert(tk.END, f"Architecture: {', '.join(info['architectures'])}\n") + if 'vocab_size' in info: + self.info_text.insert(tk.END, f"Vocab Size: {info['vocab_size']:,}\n") + + def _on_format_change(self, event=None): + """Handle output format change.""" + format_type = self.output_format_var.get() + + # Enable/disable quantization based on format + if format_type == 'gguf': + self.quant_combo.config(state="readonly") + else: + self.quant_combo.config(state="disabled") + + # Update output path + if self.input_path_var.get(): + input_path = Path(self.input_path_var.get()) + quantization = self.quantization_var.get() if format_type == 'gguf' else '' + + if quantization: + output_name = f"{input_path.stem}_{quantization}.{format_type}" + else: + output_name = f"{input_path.stem}.{format_type}" + + output_path = input_path.parent / output_name + self.output_path_var.set(str(output_path)) + + def _start_conversion(self): + """Start the conversion process.""" + input_path = self.input_path_var.get() + output_path = self.output_path_var.get() + + if not input_path or not os.path.exists(input_path): + messagebox.showerror("Error", "Please select a valid input model") + return + + if not output_path: + messagebox.showerror("Error", "Please specify an output path") + return + + # Confirm overwrite if exists + if os.path.exists(output_path): + if not messagebox.askyesno("Confirm", f"Output file exists. Overwrite?\n{output_path}"): + return + + # Disable UI + self.convert_btn.config(state=tk.DISABLED) + self.cancel_btn.config(state=tk.NORMAL) + + # Clear progress + self.progress_var.set(0) + self.progress_text.delete(1.0, tk.END) + + # Start conversion in thread + threading.Thread(target=self._conversion_thread, + args=(input_path, output_path), + daemon=True).start() + + def _conversion_thread(self, input_path: str, output_path: str): + """Run conversion in background thread.""" + def progress_callback(message: str, percent: int, level: str = "info"): + # Update UI in main thread + self.parent.after(0, self._update_progress, message, percent, level) + + try: + output_format = self.output_format_var.get() + + if output_format == 'gguf': + quantization = self.quantization_var.get() + success = self.converter.convert_to_gguf( + input_path, output_path, quantization, progress_callback + ) + else: + # Other format conversions + success = False + progress_callback(f"Conversion to {output_format} not yet implemented", 100, "error") + + if success: + self.parent.after(0, self._conversion_complete, True) + else: + self.parent.after(0, self._conversion_complete, False) + + except Exception as e: + progress_callback(f"Error: {e}", 100, "error") + self.parent.after(0, self._conversion_complete, False) + + def _update_progress(self, message: str, percent: int, level: str): + """Update progress display.""" + # Update progress bar + if percent >= 0: + self.progress_var.set(percent) + + # Add to log + timestamp = datetime.now().strftime("%H:%M:%S") + + # Color based on level + if level == "error": + self.progress_text.insert(tk.END, f"[{timestamp}] ERROR: {message}\n", "error") + elif level == "success": + self.progress_text.insert(tk.END, f"[{timestamp}] SUCCESS: {message}\n", "success") + else: + self.progress_text.insert(tk.END, f"[{timestamp}] {message}\n") + + self.progress_text.see(tk.END) + + def _conversion_complete(self, success: bool): + """Handle conversion completion.""" + self.convert_btn.config(state=tk.NORMAL) + self.cancel_btn.config(state=tk.DISABLED) + + if success: + messagebox.showinfo("Success", "Model conversion completed successfully!") + else: + messagebox.showerror("Error", "Model conversion failed. Check the log for details.") + + def _cancel_conversion(self): + """Cancel the current conversion.""" + if self.converter.current_process: + self.converter.current_process.terminate() + self._update_progress("Conversion cancelled", 0, "error") + self.convert_btn.config(state=tk.NORMAL) + self.cancel_btn.config(state=tk.DISABLED) + + def _clear_log(self): + """Clear the progress log.""" + self.progress_text.delete(1.0, tk.END) + self.progress_var.set(0) + + def _open_lora_merger(self): + """Open LoRA merger dialog.""" + messagebox.showinfo("LoRA Merger", "LoRA merging tool coming soon!") + + def _open_model_splitter(self): + """Open model splitter dialog.""" + messagebox.showinfo("Model Splitter", "Model splitting tool coming soon!") + + def _open_optimizer(self): + """Open model optimizer dialog.""" + messagebox.showinfo("Model Optimizer", "Model optimization tool coming soon!") + + def _open_batch_converter(self): + """Open batch converter dialog.""" + messagebox.showinfo("Batch Converter", "Batch conversion tool coming soon!") + + def _open_model_compare(self): + """Open model comparison dialog.""" + messagebox.showinfo("Model Compare", "Model comparison tool coming soon!") + + def _open_metadata_editor(self): + """Open metadata editor dialog.""" + messagebox.showinfo("Metadata Editor", "Metadata editing tool coming soon!") \ No newline at end of file diff --git a/model_library.py b/model_library.py new file mode 100644 index 0000000..87fb004 --- /dev/null +++ b/model_library.py @@ -0,0 +1,744 @@ +import os +import json +import time +import threading +import tkinter as tk +from tkinter import ttk, messagebox, filedialog +from typing import List, Dict, Any, Optional +from pathlib import Path +import hashlib +from dataclasses import dataclass, asdict +from datetime import datetime + + +@dataclass +class ModelInfo: + """Information about a model file.""" + name: str + path: str + size: int + modified_time: float + file_type: str + hash: str = "" + metadata: Dict[str, Any] = None + tags: List[str] = None + description: str = "" + + def __post_init__(self): + if self.metadata is None: + self.metadata = {} + if self.tags is None: + self.tags = [] + + @property + def size_mb(self) -> float: + return self.size / (1024 * 1024) + + @property + def modified_date(self) -> str: + return datetime.fromtimestamp(self.modified_time).strftime("%Y-%m-%d %H:%M") + + +class ModelLibrary: + """Manages a library of model files with scanning and indexing.""" + + def __init__(self, library_root: str, max_depth: int = 3): + self.library_root = Path(library_root) + self.max_depth = max_depth + self.models: Dict[str, ModelInfo] = {} + self.index_file = self.library_root / ".model_index.json" + self.supported_extensions = {'.gguf', '.bin', '.safetensors', '.pt', '.pth', '.onnx'} + self.scan_in_progress = False + + def scan_library(self, progress_callback: Optional[callable] = None) -> List[ModelInfo]: + """Scan the library directory for model files.""" + if self.scan_in_progress: + return list(self.models.values()) + + self.scan_in_progress = True + new_models = {} + total_files = 0 + processed_files = 0 + + try: + # Count total files first for progress + if progress_callback: + for root, dirs, files in os.walk(self.library_root): + depth = len(Path(root).relative_to(self.library_root).parts) + if depth >= self.max_depth: + dirs.clear() + total_files += len([f for f in files if Path(f).suffix.lower() in self.supported_extensions]) + + progress_callback(0, total_files, "Counting files...") + + # Scan for model files + for root, dirs, files in os.walk(self.library_root): + # Limit scan depth + depth = len(Path(root).relative_to(self.library_root).parts) + if depth >= self.max_depth: + dirs.clear() + continue + + for file in files: + file_path = Path(root) / file + + # Check if it's a supported model file + if file_path.suffix.lower() not in self.supported_extensions: + continue + + try: + stat = file_path.stat() + file_hash = self._calculate_file_hash(str(file_path)) + + # Check if file already exists in index + existing_model = self.models.get(str(file_path)) + if (existing_model and + existing_model.modified_time == stat.st_mtime and + existing_model.hash == file_hash): + # File unchanged, use existing data + new_models[str(file_path)] = existing_model + else: + # New or changed file + model_info = ModelInfo( + name=file_path.stem, + path=str(file_path), + size=stat.st_size, + modified_time=stat.st_mtime, + file_type=file_path.suffix.lower(), + hash=file_hash + ) + # Try to extract metadata + self._extract_metadata(model_info) + new_models[str(file_path)] = model_info + + processed_files += 1 + if progress_callback: + progress_callback(processed_files, total_files, f"Processing {file}") + + except Exception as e: + print(f"Error processing {file_path}: {e}") + continue + + self.models = new_models + self._save_index() + + finally: + self.scan_in_progress = False + + return list(self.models.values()) + + def _calculate_file_hash(self, file_path: str, chunk_size: int = 8192) -> str: + """Calculate MD5 hash of first and last chunks for quick identification.""" + try: + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + # Hash first chunk + chunk = f.read(chunk_size) + if chunk: + hash_md5.update(chunk) + + # Hash last chunk if file is large enough + file_size = os.path.getsize(file_path) + if file_size > chunk_size * 2: + f.seek(-chunk_size, 2) + chunk = f.read(chunk_size) + hash_md5.update(chunk) + + return hash_md5.hexdigest() + except Exception: + return "" + + def _extract_metadata(self, model_info: ModelInfo): + """Extract metadata from model files.""" + try: + file_path = Path(model_info.path) + + if model_info.file_type == '.gguf': + # Try to extract GGUF metadata + model_info.metadata = self._extract_gguf_metadata(file_path) + elif model_info.file_type in ['.bin', '.safetensors']: + # Try to extract HuggingFace metadata + model_info.metadata = self._extract_hf_metadata(file_path) + + # Extract tags from path + path_parts = file_path.parts + model_info.tags = [part.lower() for part in path_parts + if any(keyword in part.lower() + for keyword in ['q4', 'q5', 'q8', 'fp16', 'fp32', 'instruct', 'chat'])] + + except Exception as e: + print(f"Error extracting metadata from {model_info.path}: {e}") + + def _extract_gguf_metadata(self, file_path: Path) -> Dict[str, Any]: + """Extract metadata from GGUF files.""" + metadata = {} + try: + # This is a simplified version - you'd need proper GGUF parsing + # For now, extract info from filename + name = file_path.stem.lower() + + if 'q4' in name: + metadata['quantization'] = 'Q4' + elif 'q5' in name: + metadata['quantization'] = 'Q5' + elif 'q8' in name: + metadata['quantization'] = 'Q8' + elif 'fp16' in name: + metadata['quantization'] = 'FP16' + + if 'instruct' in name: + metadata['type'] = 'instruct' + elif 'chat' in name: + metadata['type'] = 'chat' + else: + metadata['type'] = 'base' + + # Extract model size if present + for part in name.split('-'): + if part.endswith('b') and part[:-1].replace('.', '').isdigit(): + metadata['parameters'] = part + break + + except Exception as e: + print(f"Error extracting GGUF metadata: {e}") + + return metadata + + def _extract_hf_metadata(self, file_path: Path) -> Dict[str, Any]: + """Extract metadata from HuggingFace model files.""" + metadata = {} + try: + # Look for config.json in the same directory + config_path = file_path.parent / "config.json" + if config_path.exists(): + with open(config_path, 'r') as f: + config = json.load(f) + metadata.update({ + 'architecture': config.get('architectures', []), + 'model_type': config.get('model_type', ''), + 'vocab_size': config.get('vocab_size', 0) + }) + except Exception as e: + print(f"Error extracting HF metadata: {e}") + + return metadata + + def _save_index(self): + """Save the model index to disk.""" + try: + index_data = { + 'last_scan': time.time(), + 'library_root': str(self.library_root), + 'max_depth': self.max_depth, + 'models': {path: asdict(model) for path, model in self.models.items()} + } + + with open(self.index_file, 'w') as f: + json.dump(index_data, f, indent=2) + + except Exception as e: + print(f"Error saving index: {e}") + + def _load_index(self): + """Load the model index from disk.""" + try: + if self.index_file.exists(): + with open(self.index_file, 'r') as f: + index_data = json.load(f) + + # Validate index is current + if (index_data.get('library_root') == str(self.library_root) and + index_data.get('max_depth') == self.max_depth): + + models_data = index_data.get('models', {}) + self.models = { + path: ModelInfo(**data) + for path, data in models_data.items() + } + + except Exception as e: + print(f"Error loading index: {e}") + + def search_models(self, query: str, file_type: str = "", tags: List[str] = None) -> List[ModelInfo]: + """Search models by name, type, or tags.""" + if tags is None: + tags = [] + + results = [] + query_lower = query.lower() + + for model in self.models.values(): + # Check name match + name_match = query_lower in model.name.lower() + + # Check file type match + type_match = not file_type or model.file_type == file_type + + # Check tags match + tags_match = not tags or any(tag in model.tags for tag in tags) + + if name_match and type_match and tags_match: + results.append(model) + + return sorted(results, key=lambda x: x.modified_time, reverse=True) + + def get_models_by_type(self) -> Dict[str, List[ModelInfo]]: + """Get models grouped by file type.""" + grouped = {} + for model in self.models.values(): + if model.file_type not in grouped: + grouped[model.file_type] = [] + grouped[model.file_type].append(model) + + return grouped + + def get_library_stats(self) -> Dict[str, Any]: + """Get library statistics.""" + if not self.models: + return {} + + total_size = sum(model.size for model in self.models.values()) + file_types = {} + + for model in self.models.values(): + if model.file_type not in file_types: + file_types[model.file_type] = {'count': 0, 'size': 0} + file_types[model.file_type]['count'] += 1 + file_types[model.file_type]['size'] += model.size + + return { + 'total_models': len(self.models), + 'total_size': total_size, + 'file_types': file_types, + 'last_scan': getattr(self, 'last_scan_time', 0) + } + + +class ModelLibraryTab: + """GUI tab for the Model Library.""" + + def __init__(self, parent: ttk.Frame, settings_manager): + self.parent = parent + self.settings = settings_manager + self.library = None + self.current_models = [] + + self._build_ui() + self._load_library() + + def _build_ui(self): + """Build the model library UI.""" + # Top toolbar + toolbar = ttk.Frame(self.parent) + toolbar.pack(fill=tk.X, padx=10, pady=5) + + # Left side controls + left_controls = ttk.Frame(toolbar) + left_controls.pack(side=tk.LEFT) + + ttk.Button(left_controls, text="Scan Library", + command=self._scan_library).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Refresh", + command=self._refresh).pack(side=tk.LEFT, padx=2) + ttk.Button(left_controls, text="Settings", + command=self._open_library_settings).pack(side=tk.LEFT, padx=2) + + # Search controls + search_frame = ttk.Frame(toolbar) + search_frame.pack(side=tk.RIGHT) + + ttk.Label(search_frame, text="Search:").pack(side=tk.LEFT, padx=2) + self.search_var = tk.StringVar() + self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=20) + self.search_entry.pack(side=tk.LEFT, padx=2) + self.search_entry.bind('', self._on_search) + + ttk.Label(search_frame, text="Type:").pack(side=tk.LEFT, padx=(10,2)) + self.type_var = tk.StringVar(value="All") + self.type_combo = ttk.Combobox(search_frame, textvariable=self.type_var, + values=["All", ".gguf", ".bin", ".safetensors", ".pt", ".onnx"], + state="readonly", width=10) + self.type_combo.pack(side=tk.LEFT, padx=2) + self.type_combo.bind('<>', self._on_filter_change) + + # Main content area + content_frame = ttk.Frame(self.parent) + content_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5) + + # Left panel - Model list + left_panel = ttk.Frame(content_frame) + left_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # Model treeview + columns = ("name", "type", "size", "modified", "tags") + self.tree = ttk.Treeview(left_panel, columns=columns, show="headings", height=20) + + self.tree.heading("name", text="Model Name") + self.tree.heading("type", text="Type") + self.tree.heading("size", text="Size") + self.tree.heading("modified", text="Modified") + self.tree.heading("tags", text="Tags") + + self.tree.column("name", width=300) + self.tree.column("type", width=80) + self.tree.column("size", width=100) + self.tree.column("modified", width=120) + self.tree.column("tags", width=150) + + # Scrollbars for treeview + v_scroll = ttk.Scrollbar(left_panel, orient="vertical", command=self.tree.yview) + h_scroll = ttk.Scrollbar(left_panel, orient="horizontal", command=self.tree.xview) + self.tree.configure(yscrollcommand=v_scroll.set, xscrollcommand=h_scroll.set) + + self.tree.grid(row=0, column=0, sticky="nsew") + v_scroll.grid(row=0, column=1, sticky="ns") + h_scroll.grid(row=1, column=0, sticky="ew") + + left_panel.grid_rowconfigure(0, weight=1) + left_panel.grid_columnconfigure(0, weight=1) + + # Right panel - Details and actions + right_panel = ttk.Frame(content_frame) + right_panel.pack(side=tk.RIGHT, fill=tk.Y, padx=(10, 0)) + + # Model details + details_frame = ttk.LabelFrame(right_panel, text="Model Details", padding=10) + details_frame.pack(fill=tk.BOTH, expand=True) + + self.details_text = tk.Text(details_frame, width=40, height=15, wrap=tk.WORD) + details_scroll = ttk.Scrollbar(details_frame, orient="vertical", command=self.details_text.yview) + self.details_text.configure(yscrollcommand=details_scroll.set) + + self.details_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + details_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Action buttons + actions_frame = ttk.LabelFrame(right_panel, text="Actions", padding=10) + actions_frame.pack(fill=tk.X, pady=(10, 0)) + + ttk.Button(actions_frame, text="Load Model", + command=self._load_selected_model).pack(fill=tk.X, pady=2) + ttk.Button(actions_frame, text="Open Folder", + command=self._open_model_folder).pack(fill=tk.X, pady=2) + ttk.Button(actions_frame, text="Copy Path", + command=self._copy_model_path).pack(fill=tk.X, pady=2) + ttk.Button(actions_frame, text="Add Tags", + command=self._add_tags).pack(fill=tk.X, pady=2) + + # Status bar + self.status_label = ttk.Label(self.parent, text="Ready") + self.status_label.pack(fill=tk.X, padx=10, pady=2) + + # Bind events + self.tree.bind('<>', self._on_model_select) + self.tree.bind('', self._load_selected_model) + + def _load_library(self): + """Load the model library based on settings.""" + library_root = self.settings.get('library.root_folder', '') + if library_root and os.path.exists(library_root): + max_depth = self.settings.get('library.max_depth', 3) + self.library = ModelLibrary(library_root, max_depth) + self.library._load_index() + self._update_model_list() + else: + self.status_label.config(text="No library folder configured. Click Settings to configure.") + + def _scan_library(self): + """Scan the library for models.""" + if not self.library: + self._open_library_settings() + return + + # Show progress dialog + progress_dialog = tk.Toplevel(self.parent) + progress_dialog.title("Scanning Library") + progress_dialog.geometry("400x150") + progress_dialog.transient(self.parent.winfo_toplevel()) + progress_dialog.grab_set() + + ttk.Label(progress_dialog, text="Scanning model library...").pack(pady=10) + + progress_var = tk.DoubleVar() + progress_bar = ttk.Progressbar(progress_dialog, variable=progress_var, + maximum=100, length=300) + progress_bar.pack(pady=10) + + status_var = tk.StringVar(value="Starting scan...") + status_label = ttk.Label(progress_dialog, textvariable=status_var) + status_label.pack(pady=5) + + def progress_callback(current, total, message): + if total > 0: + progress_var.set((current / total) * 100) + status_var.set(message) + progress_dialog.update() + + def scan_thread(): + try: + models = self.library.scan_library(progress_callback) + progress_dialog.after(0, lambda: progress_dialog.destroy()) + self.parent.after(0, self._update_model_list) + self.parent.after(0, lambda: self.status_label.config( + text=f"Scan complete. Found {len(models)} models.")) + except Exception as e: + progress_dialog.after(0, lambda: progress_dialog.destroy()) + self.parent.after(0, lambda: messagebox.showerror("Scan Error", str(e))) + + threading.Thread(target=scan_thread, daemon=True).start() + + def _refresh(self): + """Refresh the model list.""" + if self.library: + self._update_model_list() + + def _update_model_list(self): + """Update the model list display.""" + if not self.library: + return + + # Clear existing items + for item in self.tree.get_children(): + self.tree.delete(item) + + # Apply filters + query = self.search_var.get() + file_type = self.type_var.get() if self.type_var.get() != "All" else "" + + if query or file_type: + models = self.library.search_models(query, file_type) + else: + models = list(self.library.models.values()) + + self.current_models = models + + # Populate treeview + for model in models: + tags_str = ", ".join(model.tags[:3]) # Show first 3 tags + self.tree.insert("", tk.END, values=( + model.name, + model.file_type, + f"{model.size_mb:.1f} MB", + model.modified_date, + tags_str + )) + + # Update status + stats = self.library.get_library_stats() + total_size_gb = stats.get('total_size', 0) / (1024**3) + self.status_label.config( + text=f"Showing {len(models)} of {stats.get('total_models', 0)} models " + f"({total_size_gb:.1f} GB total)" + ) + + def _on_search(self, event=None): + """Handle search input.""" + self._update_model_list() + + def _on_filter_change(self, event=None): + """Handle filter change.""" + self._update_model_list() + + def _on_model_select(self, event): + """Handle model selection.""" + selection = self.tree.selection() + if not selection: + return + + item = self.tree.item(selection[0]) + model_name = item['values'][0] + + # Find the selected model + selected_model = None + for model in self.current_models: + if model.name == model_name: + selected_model = model + break + + if selected_model: + self._show_model_details(selected_model) + + def _show_model_details(self, model: ModelInfo): + """Show detailed information about a model.""" + details = f"Name: {model.name}\\n" + details += f"Path: {model.path}\\n" + details += f"Type: {model.file_type}\\n" + details += f"Size: {model.size_mb:.1f} MB\\n" + details += f"Modified: {model.modified_date}\\n" + + if model.tags: + details += f"Tags: {', '.join(model.tags)}\\n" + + if model.metadata: + details += "\\nMetadata:\\n" + for key, value in model.metadata.items(): + details += f" {key}: {value}\\n" + + if model.description: + details += f"\\nDescription:\\n{model.description}" + + self.details_text.delete(1.0, tk.END) + self.details_text.insert(1.0, details) + + def _load_selected_model(self, event=None): + """Load the selected model in the main application.""" + selection = self.tree.selection() + if not selection: + messagebox.showinfo("No Selection", "Please select a model to load") + return + + item = self.tree.item(selection[0]) + model_name = item['values'][0] + + # Find the selected model + selected_model = None + for model in self.current_models: + if model.name == model_name: + selected_model = model + break + + if selected_model: + # Set the model path in the main application + parent_window = self.parent.winfo_toplevel() + if hasattr(parent_window, 'model_var'): + parent_window.model_var.set(selected_model.path) + messagebox.showinfo("Model Loaded", f"Loaded model: {selected_model.name}") + + def _open_model_folder(self): + """Open the folder containing the selected model.""" + selection = self.tree.selection() + if not selection: + return + + item = self.tree.item(selection[0]) + model_name = item['values'][0] + + selected_model = None + for model in self.current_models: + if model.name == model_name: + selected_model = model + break + + if selected_model: + folder = os.path.dirname(selected_model.path) + import subprocess + import platform + + try: + if platform.system() == "Windows": + subprocess.run(["explorer", folder]) + elif platform.system() == "Darwin": + subprocess.run(["open", folder]) + else: + subprocess.run(["xdg-open", folder]) + except Exception as e: + messagebox.showerror("Error", f"Could not open folder: {e}") + + def _copy_model_path(self): + """Copy the selected model's path to clipboard.""" + selection = self.tree.selection() + if not selection: + return + + item = self.tree.item(selection[0]) + model_name = item['values'][0] + + selected_model = None + for model in self.current_models: + if model.name == model_name: + selected_model = model + break + + if selected_model: + self.parent.clipboard_clear() + self.parent.clipboard_append(selected_model.path) + self.status_label.config(text="Model path copied to clipboard") + + def _add_tags(self): + """Add tags to the selected model.""" + selection = self.tree.selection() + if not selection: + return + + item = self.tree.item(selection[0]) + model_name = item['values'][0] + + selected_model = None + for model in self.current_models: + if model.name == model_name: + selected_model = model + break + + if selected_model: + import tkinter.simpledialog as simpledialog + + current_tags = ", ".join(selected_model.tags) + new_tags = simpledialog.askstring( + "Add Tags", + f"Enter tags (comma-separated):\\nCurrent: {current_tags}", + initialvalue=current_tags + ) + + if new_tags is not None: + selected_model.tags = [tag.strip() for tag in new_tags.split(",") if tag.strip()] + self.library._save_index() + self._update_model_list() + self._show_model_details(selected_model) + + def _open_library_settings(self): + """Open library settings dialog.""" + dialog = tk.Toplevel(self.parent) + dialog.title("Library Settings") + dialog.geometry("500x300") + dialog.transient(self.parent.winfo_toplevel()) + dialog.grab_set() + + # Root folder setting + ttk.Label(dialog, text="Library Root Folder:").pack(anchor=tk.W, padx=10, pady=5) + + folder_frame = ttk.Frame(dialog) + folder_frame.pack(fill=tk.X, padx=10, pady=5) + + folder_var = tk.StringVar(value=self.settings.get('library.root_folder', '')) + folder_entry = ttk.Entry(folder_frame, textvariable=folder_var, width=50) + folder_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + + def browse_folder(): + folder = filedialog.askdirectory(initialdir=folder_var.get()) + if folder: + folder_var.set(folder) + + ttk.Button(folder_frame, text="Browse", command=browse_folder).pack(side=tk.LEFT, padx=5) + + # Max depth setting + ttk.Label(dialog, text="Maximum Scan Depth:").pack(anchor=tk.W, padx=10, pady=(10,5)) + + depth_var = tk.IntVar(value=self.settings.get('library.max_depth', 3)) + depth_frame = ttk.Frame(dialog) + depth_frame.pack(fill=tk.X, padx=10, pady=5) + + ttk.Scale(depth_frame, from_=1, to=10, variable=depth_var, + orient=tk.HORIZONTAL, length=200).pack(side=tk.LEFT) + ttk.Label(depth_frame, textvariable=depth_var).pack(side=tk.LEFT, padx=10) + ttk.Label(depth_frame, text="levels").pack(side=tk.LEFT) + + # Info label + info_text = ("Scan depth determines how many subdirectory levels to search.\\n" + "Higher values find more models but take longer to scan.") + ttk.Label(dialog, text=info_text, foreground="gray").pack(padx=10, pady=10) + + # Buttons + button_frame = ttk.Frame(dialog) + button_frame.pack(fill=tk.X, padx=10, pady=10) + + def save_settings(): + self.settings.set('library.root_folder', folder_var.get()) + self.settings.set('library.max_depth', depth_var.get()) + self.settings.save_settings() + + # Reload library + self._load_library() + dialog.destroy() + + ttk.Button(button_frame, text="Save", command=save_settings).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="Cancel", command=dialog.destroy).pack(side=tk.RIGHT) \ No newline at end of file diff --git a/pentestgpt.py b/pentestgpt.py new file mode 100644 index 0000000..9581040 --- /dev/null +++ b/pentestgpt.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +""" +PentestGPT Tab Implementation +Handles the PentestGPT interface and controls for DarkHal 2.0 +""" + +import tkinter as tk +from tkinter import ttk, messagebox, filedialog +from pathlib import Path +import subprocess +import threading + +class PentestGPTTab: + """PentestGPT tab with penetration testing agent controls.""" + + def __init__(self, parent: ttk.Frame, settings_manager, main_app=None): + self.parent = parent + self.settings = settings_manager + self.main_app = main_app + self.pentestgpt_process = None + + # Create main frame + self.main_frame = ttk.Frame(parent) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create PentestGPT interface + self._create_pentestgpt_interface() + + def _create_pentestgpt_interface(self): + """Create PentestGPT control and configuration interface.""" + + # PentestGPT configuration frame + config_frame = ttk.LabelFrame(self.main_frame, text="PentestGPT Configuration", padding=10) + config_frame.pack(fill=tk.X, pady=(0, 10)) + + # Configuration options + options_frame = ttk.Frame(config_frame) + options_frame.pack(fill=tk.X, pady=10) + + # Model selection + ttk.Label(options_frame, text="Model:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.pentestgpt_model_var = tk.StringVar(value="gpt-4") + model_combo = ttk.Combobox(options_frame, textvariable=self.pentestgpt_model_var, + values=["gpt-4", "gpt-3.5-turbo", "claude-3", "local-llm"], width=15) + model_combo.grid(row=0, column=1, padx=10) + + # Target type + ttk.Label(options_frame, text="Target Type:").grid(row=0, column=2, sticky=tk.W, padx=(20, 0)) + self.target_type_var = tk.StringVar(value="web") + target_combo = ttk.Combobox(options_frame, textvariable=self.target_type_var, + values=["web", "network", "mobile", "cloud", "iot"], width=15) + target_combo.grid(row=0, column=3, padx=10) + + # Scan depth + ttk.Label(options_frame, text="Scan Depth:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.scan_depth_var = tk.StringVar(value="medium") + depth_combo = ttk.Combobox(options_frame, textvariable=self.scan_depth_var, + values=["light", "medium", "deep", "comprehensive"], width=15) + depth_combo.grid(row=1, column=1, padx=10) + + # Output directory + ttk.Label(options_frame, text="Output Dir:").grid(row=1, column=2, sticky=tk.W, padx=(20, 0)) + self.output_dir_var = tk.StringVar(value="./pentest_results") + ttk.Entry(options_frame, textvariable=self.output_dir_var, width=25).grid(row=1, column=3, padx=(10, 5)) + ttk.Button(options_frame, text="Browse", command=self._browse_output_dir).grid(row=1, column=4, padx=5) + + # API Configuration + api_frame = ttk.LabelFrame(config_frame, text="API Configuration", padding=5) + api_frame.pack(fill=tk.X, pady=(0, 10)) + + # OpenAI API Key + ttk.Label(api_frame, text="OpenAI API Key:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.openai_key_var = tk.StringVar() + ttk.Entry(api_frame, textvariable=self.openai_key_var, show="*", width=40).grid(row=0, column=1, padx=10) + + # Claude API Key + ttk.Label(api_frame, text="Claude API Key:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.claude_key_var = tk.StringVar() + ttk.Entry(api_frame, textvariable=self.claude_key_var, show="*", width=40).grid(row=1, column=1, padx=10) + + # Control buttons + control_frame = ttk.Frame(config_frame) + control_frame.pack(fill=tk.X, pady=10) + + self.pentestgpt_start_btn = ttk.Button(control_frame, text="Start PentestGPT", + command=self._start_pentestgpt) + self.pentestgpt_start_btn.pack(side=tk.LEFT, padx=5) + + self.pentestgpt_stop_btn = ttk.Button(control_frame, text="Stop PentestGPT", + command=self._stop_pentestgpt, state="disabled") + self.pentestgpt_stop_btn.pack(side=tk.LEFT, padx=5) + + # Configuration management + config_btn_frame = ttk.Frame(control_frame) + config_btn_frame.pack(side=tk.RIGHT, padx=5) + + ttk.Button(config_btn_frame, text="Save Config", + command=self._save_pentestgpt_config).pack(side=tk.LEFT, padx=2) + ttk.Button(config_btn_frame, text="Load Config", + command=self._load_pentestgpt_config).pack(side=tk.LEFT, padx=2) + ttk.Button(config_btn_frame, text="Configure", + command=self._configure_pentestgpt).pack(side=tk.LEFT, padx=2) + + # Output area + output_frame = ttk.LabelFrame(self.main_frame, text="PentestGPT Output", padding=10) + output_frame.pack(fill=tk.BOTH, expand=True, pady=10) + + self.pentestgpt_output = tk.Text(output_frame, height=20, wrap=tk.WORD) + pentestgpt_scrollbar = ttk.Scrollbar(output_frame, orient=tk.VERTICAL, command=self.pentestgpt_output.yview) + self.pentestgpt_output.configure(yscrollcommand=pentestgpt_scrollbar.set) + self.pentestgpt_output.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + pentestgpt_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # Status bar + self.pentestgpt_status_var = tk.StringVar(value="PentestGPT Status: Ready") + ttk.Label(output_frame, textvariable=self.pentestgpt_status_var).pack(anchor=tk.W, pady=(5, 0)) + + def _start_pentestgpt(self): + """Start PentestGPT process.""" + self.pentestgpt_output.insert(tk.END, "Starting PentestGPT...\n") + self.pentestgpt_output.insert(tk.END, f"Model: {self.pentestgpt_model_var.get()}\n") + self.pentestgpt_output.insert(tk.END, f"Target Type: {self.target_type_var.get()}\n") + self.pentestgpt_output.insert(tk.END, f"Scan Depth: {self.scan_depth_var.get()}\n\n") + + # Check if PentestGPT is available + pentestgpt_path = Path("experimental_agent/PentestGPT") + if not pentestgpt_path.exists(): + self.pentestgpt_output.insert(tk.END, "Error: PentestGPT not found in experimental_agent directory\n") + self.pentestgpt_status_var.set("PentestGPT Status: Error - Not Found") + return + + # This would start the actual PentestGPT process + # For now, we'll simulate the startup + self.pentestgpt_output.insert(tk.END, "PentestGPT started successfully!\n") + self.pentestgpt_output.insert(tk.END, "Ready for penetration testing commands...\n\n") + + # Update button states + self.pentestgpt_start_btn.config(state="disabled") + self.pentestgpt_stop_btn.config(state="normal") + self.pentestgpt_status_var.set("PentestGPT Status: Running") + + def _stop_pentestgpt(self): + """Stop PentestGPT process.""" + if self.pentestgpt_process: + self.pentestgpt_process.terminate() + self.pentestgpt_process = None + + self.pentestgpt_output.insert(tk.END, "PentestGPT stopped.\n\n") + + # Update button states + self.pentestgpt_start_btn.config(state="normal") + self.pentestgpt_stop_btn.config(state="disabled") + self.pentestgpt_status_var.set("PentestGPT Status: Stopped") + + def _browse_output_dir(self): + """Browse for output directory.""" + directory = filedialog.askdirectory() + if directory: + self.output_dir_var.set(directory) + + def _save_pentestgpt_config(self): + """Save PentestGPT configuration.""" + config = { + "model": self.pentestgpt_model_var.get(), + "target_type": self.target_type_var.get(), + "scan_depth": self.scan_depth_var.get(), + "output_dir": self.output_dir_var.get(), + "openai_key": self.openai_key_var.get(), + "claude_key": self.claude_key_var.get() + } + + try: + config_file = Path("pentestgpt_config.json") + import json + with open(config_file, 'w') as f: + json.dump(config, f, indent=2) + self.pentestgpt_output.insert(tk.END, f"Configuration saved to {config_file}\n") + except Exception as e: + messagebox.showerror("Save Error", f"Failed to save configuration: {str(e)}") + + def _load_pentestgpt_config(self): + """Load PentestGPT configuration.""" + try: + config_file = Path("pentestgpt_config.json") + if not config_file.exists(): + messagebox.showwarning("Load Error", "No configuration file found") + return + + import json + with open(config_file, 'r') as f: + config = json.load(f) + + # Apply configuration + self.pentestgpt_model_var.set(config.get("model", "gpt-4")) + self.target_type_var.set(config.get("target_type", "web")) + self.scan_depth_var.set(config.get("scan_depth", "medium")) + self.output_dir_var.set(config.get("output_dir", "./pentest_results")) + self.openai_key_var.set(config.get("openai_key", "")) + self.claude_key_var.set(config.get("claude_key", "")) + + self.pentestgpt_output.insert(tk.END, f"Configuration loaded from {config_file}\n") + + except Exception as e: + messagebox.showerror("Load Error", f"Failed to load configuration: {str(e)}") + + def _configure_pentestgpt(self): + """Open PentestGPT configuration dialog.""" + PentestGPTConfigDialog(self.parent, self) + + +class PentestGPTConfigDialog: + """Dialog for configuring PentestGPT API keys and settings.""" + + def __init__(self, parent, pentestgpt_tab): + self.pentestgpt_tab = pentestgpt_tab + self.dialog = tk.Toplevel(parent) + self.dialog.title("PentestGPT Configuration") + self.dialog.geometry("500x400") + self.dialog.transient(parent) + self.dialog.grab_set() + + # Create configuration interface + self._create_config_interface() + + # Center dialog + self.dialog.update_idletasks() + x = (self.dialog.winfo_screenwidth() // 2) - (self.dialog.winfo_width() // 2) + y = (self.dialog.winfo_screenheight() // 2) - (self.dialog.winfo_height() // 2) + self.dialog.geometry(f"+{x}+{y}") + + def _create_config_interface(self): + """Create the configuration interface.""" + main_frame = ttk.Frame(self.dialog) + main_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=20) + + # API Keys section + api_frame = ttk.LabelFrame(main_frame, text="API Configuration", padding=10) + api_frame.pack(fill=tk.X, pady=(0, 20)) + + # OpenAI API Key + ttk.Label(api_frame, text="OpenAI API Key:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.openai_entry = ttk.Entry(api_frame, show="*", width=40) + self.openai_entry.grid(row=0, column=1, padx=(10, 0), pady=5) + self.openai_entry.insert(0, self.pentestgpt_tab.openai_key_var.get()) + + # Claude API Key + ttk.Label(api_frame, text="Claude API Key:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.claude_entry = ttk.Entry(api_frame, show="*", width=40) + self.claude_entry.grid(row=1, column=1, padx=(10, 0), pady=5) + self.claude_entry.insert(0, self.pentestgpt_tab.claude_key_var.get()) + + # Advanced Settings section + advanced_frame = ttk.LabelFrame(main_frame, text="Advanced Settings", padding=10) + advanced_frame.pack(fill=tk.X, pady=(0, 20)) + + # Max concurrent scans + ttk.Label(advanced_frame, text="Max Concurrent Scans:").grid(row=0, column=0, sticky=tk.W, pady=5) + self.max_scans_var = tk.StringVar(value="3") + ttk.Entry(advanced_frame, textvariable=self.max_scans_var, width=10).grid(row=0, column=1, padx=(10, 0), pady=5) + + # Timeout settings + ttk.Label(advanced_frame, text="Request Timeout (s):").grid(row=1, column=0, sticky=tk.W, pady=5) + self.timeout_var = tk.StringVar(value="30") + ttk.Entry(advanced_frame, textvariable=self.timeout_var, width=10).grid(row=1, column=1, padx=(10, 0), pady=5) + + # Buttons + button_frame = ttk.Frame(main_frame) + button_frame.pack(fill=tk.X, pady=(20, 0)) + + ttk.Button(button_frame, text="Save", command=self._save_config).pack(side=tk.RIGHT, padx=(10, 0)) + ttk.Button(button_frame, text="Cancel", command=self._cancel).pack(side=tk.RIGHT) + + def _save_config(self): + """Save the configuration.""" + try: + # Update the main tab variables + self.pentestgpt_tab.openai_key_var.set(self.openai_entry.get()) + self.pentestgpt_tab.claude_key_var.set(self.claude_entry.get()) + + # Save to file + config = { + "openai_key": self.openai_entry.get(), + "claude_key": self.claude_entry.get(), + "max_scans": self.max_scans_var.get(), + "timeout": self.timeout_var.get() + } + + import json + config_file = Path("pentestgpt_advanced_config.json") + with open(config_file, 'w') as f: + json.dump(config, f, indent=2) + + messagebox.showinfo("Success", "Configuration saved successfully!") + self.dialog.destroy() + + except Exception as e: + messagebox.showerror("Save Error", f"Failed to save configuration: {str(e)}") + + def _cancel(self): + """Cancel the configuration dialog.""" + self.dialog.destroy() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6864c72 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,117 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "darkhal" +version = "2.0.0" +description = "DarkHal 2.0 - Advanced Local LLM Interface with AI Tasks and Games" +readme = "README.md" +license = "LGPL-2.1-only" +authors = [ + {name = "Setec Labs", email = "contact@seteclabs.com"} +] +maintainers = [ + {name = "Setec Labs", email = "contact@seteclabs.com"} +] +keywords = ["llm", "ai", "chat", "local", "privacy", "chess", "automation"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: End Users/Desktop", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Communications :: Chat", + "Topic :: Games/Entertainment :: Board Games", + "Topic :: System :: Systems Administration :: Authentication/Directory" +] +requires-python = ">=3.12" +dependencies = [ + "llama-cpp-python>=0.2.80", + "huggingface_hub>=0.24.0", + "python-chess==1.999", + "requests>=2.32.0", + "numpy>=1.21.0", + "Pillow>=9.0.0", + "psutil>=5.8.0" +] + +[project.optional-dependencies] +gpu = [ + "torch>=2.0.0", + "transformers>=4.30.0" +] +dev = [ + "pytest>=7.0.0", + "black>=22.0.0", + "flake8>=5.0.0", + "mypy>=1.0.0" +] +all = [ + "darkhal[gpu,dev]" +] + +[project.urls] +Homepage = "https://github.com/seteclabs/darkhal" +Repository = "https://github.com/seteclabs/darkhal.git" +Documentation = "https://github.com/seteclabs/darkhal/wiki" +"Bug Tracker" = "https://github.com/seteclabs/darkhal/issues" + +[project.scripts] +darkhal = "darkhal.main:main" +darkhal-chess = "darkhal.chess_window:main" + +[project.gui-scripts] +darkhal-gui = "darkhal.main:main" + +[tool.setuptools] +package-dir = {"" = "."} + +[tool.setuptools.packages.find] +where = ["."] +include = ["darkhal*"] +exclude = ["tests*", "win-install*", "debian-install*", "llm_chess-main*", "temp_llm_chess*", "build*", "dist*"] + +[tool.setuptools.package-data] +darkhal = [ + "assets/*", + "engines/*", + "*.json", + "*.md", + "*.txt", + "*.env" +] + +[tool.black] +line-length = 88 +target-version = ['py312'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # Directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] \ No newline at end of file diff --git a/remotecontrol.py b/remotecontrol.py new file mode 100644 index 0000000..9f5d3e6 --- /dev/null +++ b/remotecontrol.py @@ -0,0 +1,907 @@ +#!/usr/bin/env python3 +""" +LLM_Train Remote Control + +A standalone GUI application for remotely controlling the LLM_Train MCP server. +Allows users to connect to the server, load models, configure settings, and +perform inference operations remotely. +""" + +import asyncio +import json +import sys +import threading +import time +import tkinter as tk +from tkinter import ttk, messagebox, scrolledtext, filedialog +from typing import Dict, Any, List, Optional, Callable +import subprocess +import os +from pathlib import Path +import queue +from datetime import datetime + + +class MCPClient: + """Client for connecting to MCP server via subprocess.""" + + def __init__(self): + self.process = None + self.connected = False + self.request_id = 0 + self.callbacks: Dict[str, List[Callable]] = { + 'on_connect': [], + 'on_disconnect': [], + 'on_error': [], + 'on_response': [] + } + self.pending_requests: Dict[int, Callable] = {} + self.reader_thread = None + self.writer_queue = queue.Queue() + self.writer_thread = None + + def register_callback(self, event: str, callback: Callable): + """Register a callback for client events.""" + if event in self.callbacks: + self.callbacks[event].append(callback) + + def _trigger_callback(self, event: str, *args, **kwargs): + """Trigger callbacks for an event.""" + for callback in self.callbacks.get(event, []): + try: + callback(*args, **kwargs) + except Exception as e: + print(f"Callback error: {e}") + + async def connect(self, server_path: str = "mcp_server.py"): + """Connect to the MCP server.""" + try: + # Start the MCP server process + self.process = subprocess.Popen( + [sys.executable, server_path], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=0 + ) + + self.connected = True + + # Start reader and writer threads + self.reader_thread = threading.Thread(target=self._reader_loop, daemon=True) + self.writer_thread = threading.Thread(target=self._writer_loop, daemon=True) + + self.reader_thread.start() + self.writer_thread.start() + + # Send initialization request + init_request = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": { + "listChanged": True + }, + "sampling": {} + }, + "clientInfo": { + "name": "LLM_Train Remote Control", + "version": "1.0.0" + } + }, + "id": self._get_request_id() + } + + await self._send_request(init_request) + self._trigger_callback('on_connect') + + return True + + except Exception as e: + self._trigger_callback('on_error', f"Connection failed: {e}") + return False + + def disconnect(self): + """Disconnect from the MCP server.""" + self.connected = False + + if self.process: + try: + self.process.terminate() + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + except Exception: + pass + self.process = None + + self._trigger_callback('on_disconnect') + + def _get_request_id(self) -> int: + """Get next request ID.""" + self.request_id += 1 + return self.request_id + + async def _send_request(self, request: Dict[str, Any], callback: Optional[Callable] = None): + """Send a request to the server.""" + if not self.connected or not self.process: + return + + request_id = request.get('id') + if request_id and callback: + self.pending_requests[request_id] = callback + + # Queue the request for the writer thread + self.writer_queue.put(json.dumps(request) + '\n') + + def _writer_loop(self): + """Writer thread loop.""" + while self.connected and self.process: + try: + # Get request from queue + request_str = self.writer_queue.get(timeout=1) + + if self.process and self.process.stdin: + self.process.stdin.write(request_str) + self.process.stdin.flush() + + except queue.Empty: + continue + except Exception as e: + if self.connected: + self._trigger_callback('on_error', f"Write error: {e}") + break + + def _reader_loop(self): + """Reader thread loop.""" + while self.connected and self.process: + try: + if self.process and self.process.stdout: + line = self.process.stdout.readline() + if not line: + break + + line = line.strip() + if line: + try: + response = json.loads(line) + self._handle_response(response) + except json.JSONDecodeError as e: + self._trigger_callback('on_error', f"JSON decode error: {e}") + + except Exception as e: + if self.connected: + self._trigger_callback('on_error', f"Read error: {e}") + break + + # Connection lost + if self.connected: + self.connected = False + self._trigger_callback('on_disconnect') + + def _handle_response(self, response: Dict[str, Any]): + """Handle response from server.""" + request_id = response.get('id') + + if request_id and request_id in self.pending_requests: + callback = self.pending_requests.pop(request_id) + callback(response) + + self._trigger_callback('on_response', response) + + async def list_tools(self, callback: Optional[Callable] = None): + """List available tools on the server.""" + request = { + "jsonrpc": "2.0", + "method": "tools/list", + "id": self._get_request_id() + } + await self._send_request(request, callback) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], callback: Optional[Callable] = None): + """Call a tool on the server.""" + request = { + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + }, + "id": self._get_request_id() + } + await self._send_request(request, callback) + + +class RemoteControlGUI: + """Main GUI for the remote control application.""" + + def __init__(self): + self.root = tk.Tk() + self.root.title("LLM_Train Remote Control") + self.root.geometry("1000x700") + self.root.minsize(800, 600) + + # Initialize MCP client + self.client = MCPClient() + self.client.register_callback('on_connect', self._on_connect) + self.client.register_callback('on_disconnect', self._on_disconnect) + self.client.register_callback('on_error', self._on_error) + + # UI state + self.connected = False + self.available_tools = [] + self.available_models = [] + self.current_model = None + self.system_info = {} + + # Setup UI + self._setup_ui() + + # Setup asyncio loop for MCP client + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) + self.loop_thread.start() + + def _run_event_loop(self): + """Run asyncio event loop in separate thread.""" + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def _setup_ui(self): + """Setup the main UI.""" + # Menu bar + self._create_menu() + + # Main container + main_frame = ttk.Frame(self.root) + main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Connection frame + self._create_connection_frame(main_frame) + + # Notebook for tabs + self.notebook = ttk.Notebook(main_frame) + self.notebook.pack(fill=tk.BOTH, expand=True, pady=(10, 0)) + + # Model Management tab + self._create_model_tab() + + # Inference tab + self._create_inference_tab() + + # System Info tab + self._create_system_tab() + + # Log tab + self._create_log_tab() + + # Status bar + self._create_status_bar(main_frame) + + def _create_menu(self): + """Create menu bar.""" + menubar = tk.Menu(self.root) + self.root.config(menu=menubar) + + # File menu + file_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="File", menu=file_menu) + file_menu.add_command(label="Connect to Server", command=self._connect_dialog) + file_menu.add_separator() + file_menu.add_command(label="Exit", command=self.root.quit) + + # Tools menu + tools_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Tools", menu=tools_menu) + tools_menu.add_command(label="Refresh Models", command=self._refresh_models) + tools_menu.add_command(label="Get System Info", command=self._get_system_info) + + # Help menu + help_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="Help", menu=help_menu) + help_menu.add_command(label="About", command=self._show_about) + + def _create_connection_frame(self, parent): + """Create connection status frame.""" + conn_frame = ttk.LabelFrame(parent, text="Connection", padding=10) + conn_frame.pack(fill=tk.X, pady=(0, 10)) + + # Connection status + status_frame = ttk.Frame(conn_frame) + status_frame.pack(fill=tk.X) + + ttk.Label(status_frame, text="Status:").pack(side=tk.LEFT) + self.connection_status = ttk.Label(status_frame, text="Disconnected", foreground="red") + self.connection_status.pack(side=tk.LEFT, padx=(5, 20)) + + # Server path + ttk.Label(status_frame, text="Server:").pack(side=tk.LEFT) + self.server_path_var = tk.StringVar(value="mcp_server.py") + self.server_entry = ttk.Entry(status_frame, textvariable=self.server_path_var, width=30) + self.server_entry.pack(side=tk.LEFT, padx=5) + + # Browse button + ttk.Button(status_frame, text="Browse", command=self._browse_server).pack(side=tk.LEFT, padx=2) + + # Connect/Disconnect button + self.connect_btn = ttk.Button(status_frame, text="Connect", command=self._toggle_connection) + self.connect_btn.pack(side=tk.RIGHT, padx=5) + + def _create_model_tab(self): + """Create model management tab.""" + model_frame = ttk.Frame(self.notebook) + self.notebook.add(model_frame, text="Models") + + # Model list frame + list_frame = ttk.LabelFrame(model_frame, text="Available Models", padding=10) + list_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Model tree + columns = ("name", "type", "size", "status") + self.model_tree = ttk.Treeview(list_frame, columns=columns, show="headings", height=12) + + self.model_tree.heading("name", text="Model Name") + self.model_tree.heading("type", text="Type") + self.model_tree.heading("size", text="Size") + self.model_tree.heading("status", text="Status") + + self.model_tree.column("name", width=300) + self.model_tree.column("type", width=100) + self.model_tree.column("size", width=100) + self.model_tree.column("status", width=100) + + # Scrollbar + model_scroll = ttk.Scrollbar(list_frame, orient="vertical", command=self.model_tree.yview) + self.model_tree.configure(yscrollcommand=model_scroll.set) + + self.model_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + model_scroll.pack(side=tk.RIGHT, fill=tk.Y) + + # Model controls + control_frame = ttk.Frame(model_frame) + control_frame.pack(fill=tk.X, padx=10, pady=10) + + ttk.Button(control_frame, text="Refresh List", command=self._refresh_models).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Load Model", command=self._load_selected_model).pack(side=tk.LEFT, padx=5) + ttk.Button(control_frame, text="Unload Model", command=self._unload_model).pack(side=tk.LEFT, padx=5) + + # Current model info + current_frame = ttk.LabelFrame(model_frame, text="Current Model", padding=10) + current_frame.pack(fill=tk.X, padx=10, pady=10) + + self.current_model_label = ttk.Label(current_frame, text="No model loaded", font=("TkDefaultFont", 10, "bold")) + self.current_model_label.pack(anchor=tk.W) + + # Model configuration + config_frame = ttk.LabelFrame(model_frame, text="Model Configuration", padding=10) + config_frame.pack(fill=tk.X, padx=10, pady=10) + + # Context size + ctx_frame = ttk.Frame(config_frame) + ctx_frame.pack(fill=tk.X, pady=2) + ttk.Label(ctx_frame, text="Context Size:").pack(side=tk.LEFT) + self.ctx_var = tk.IntVar(value=4096) + ctx_spin = tk.Spinbox(ctx_frame, from_=512, to=32768, increment=512, textvariable=self.ctx_var, width=10) + ctx_spin.pack(side=tk.LEFT, padx=5) + + # GPU layers + gpu_frame = ttk.Frame(config_frame) + gpu_frame.pack(fill=tk.X, pady=2) + ttk.Label(gpu_frame, text="GPU Layers:").pack(side=tk.LEFT) + self.gpu_var = tk.IntVar(value=0) + gpu_spin = tk.Spinbox(gpu_frame, from_=0, to=100, increment=1, textvariable=self.gpu_var, width=10) + gpu_spin.pack(side=tk.LEFT, padx=5) + + def _create_inference_tab(self): + """Create inference tab.""" + inference_frame = ttk.Frame(self.notebook) + self.notebook.add(inference_frame, text="Inference") + + # Input frame + input_frame = ttk.LabelFrame(inference_frame, text="Input", padding=10) + input_frame.pack(fill=tk.X, padx=10, pady=10) + + # Prompt input + self.prompt_text = scrolledtext.ScrolledText(input_frame, height=6, wrap=tk.WORD) + self.prompt_text.pack(fill=tk.X, pady=5) + + # Generation controls + controls_frame = ttk.Frame(input_frame) + controls_frame.pack(fill=tk.X, pady=5) + + # Max tokens + ttk.Label(controls_frame, text="Max Tokens:").pack(side=tk.LEFT) + self.max_tokens_var = tk.IntVar(value=256) + max_tokens_spin = tk.Spinbox(controls_frame, from_=1, to=8192, increment=16, textvariable=self.max_tokens_var, width=10) + max_tokens_spin.pack(side=tk.LEFT, padx=5) + + # Temperature + ttk.Label(controls_frame, text="Temperature:").pack(side=tk.LEFT, padx=(20, 0)) + self.temperature_var = tk.DoubleVar(value=0.7) + temp_spin = tk.Spinbox(controls_frame, from_=0.0, to=2.0, increment=0.1, textvariable=self.temperature_var, width=10) + temp_spin.pack(side=tk.LEFT, padx=5) + + # Generate button + self.generate_btn = ttk.Button(controls_frame, text="Generate", command=self._generate_text) + self.generate_btn.pack(side=tk.RIGHT, padx=5) + + # Output frame + output_frame = ttk.LabelFrame(inference_frame, text="Output", padding=10) + output_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + self.output_text = scrolledtext.ScrolledText(output_frame, height=15, wrap=tk.WORD, state=tk.DISABLED) + self.output_text.pack(fill=tk.BOTH, expand=True) + + # Chat mode + chat_frame = ttk.LabelFrame(inference_frame, text="Chat Mode", padding=10) + chat_frame.pack(fill=tk.X, padx=10, pady=10) + + chat_controls = ttk.Frame(chat_frame) + chat_controls.pack(fill=tk.X) + + self.chat_mode_var = tk.BooleanVar() + ttk.Checkbutton(chat_controls, text="Enable Chat Mode", variable=self.chat_mode_var).pack(side=tk.LEFT) + + ttk.Button(chat_controls, text="Clear History", command=self._clear_chat).pack(side=tk.RIGHT) + + def _create_system_tab(self): + """Create system info tab.""" + system_frame = ttk.Frame(self.notebook) + self.notebook.add(system_frame, text="System") + + # System info display + info_frame = ttk.LabelFrame(system_frame, text="System Information", padding=10) + info_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + self.system_text = scrolledtext.ScrolledText(info_frame, height=20, wrap=tk.WORD, state=tk.DISABLED) + self.system_text.pack(fill=tk.BOTH, expand=True) + + # Refresh button + ttk.Button(system_frame, text="Refresh System Info", command=self._get_system_info).pack(pady=10) + + def _create_log_tab(self): + """Create log tab.""" + log_frame = ttk.Frame(self.notebook) + self.notebook.add(log_frame, text="Log") + + # Log display + self.log_text = scrolledtext.ScrolledText(log_frame, height=25, wrap=tk.WORD, state=tk.DISABLED) + self.log_text.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Log controls + log_controls = ttk.Frame(log_frame) + log_controls.pack(fill=tk.X, padx=10, pady=10) + + ttk.Button(log_controls, text="Clear Log", command=self._clear_log).pack(side=tk.LEFT) + ttk.Button(log_controls, text="Save Log", command=self._save_log).pack(side=tk.LEFT, padx=5) + + def _create_status_bar(self, parent): + """Create status bar.""" + self.status_bar = ttk.Label(parent, text="Ready", relief=tk.SUNKEN, anchor=tk.W) + self.status_bar.pack(fill=tk.X, side=tk.BOTTOM) + + def _log(self, message: str, level: str = "INFO"): + """Add message to log.""" + timestamp = datetime.now().strftime("%H:%M:%S") + log_entry = f"[{timestamp}] {level}: {message}\n" + + self.log_text.config(state=tk.NORMAL) + self.log_text.insert(tk.END, log_entry) + self.log_text.see(tk.END) + self.log_text.config(state=tk.DISABLED) + + def _set_status(self, message: str): + """Set status bar message.""" + self.status_bar.config(text=message) + + def _connect_dialog(self): + """Show connection dialog.""" + if self.connected: + self._disconnect() + else: + self._connect() + + def _browse_server(self): + """Browse for server script.""" + file_path = filedialog.askopenfilename( + title="Select MCP Server Script", + filetypes=[("Python files", "*.py"), ("All files", "*.*")] + ) + if file_path: + self.server_path_var.set(file_path) + + def _toggle_connection(self): + """Toggle connection to server.""" + if self.connected: + self._disconnect() + else: + self._connect() + + def _connect(self): + """Connect to MCP server.""" + server_path = self.server_path_var.get().strip() + if not server_path: + messagebox.showerror("Error", "Please specify server path") + return + + if not os.path.exists(server_path): + messagebox.showerror("Error", f"Server file not found: {server_path}") + return + + self._log(f"Connecting to server: {server_path}") + self._set_status("Connecting...") + + # Connect in asyncio thread + future = asyncio.run_coroutine_threadsafe(self.client.connect(server_path), self.loop) + + def check_result(): + if future.done(): + try: + success = future.result() + if success: + self._log("Connected successfully") + else: + self._log("Connection failed", "ERROR") + except Exception as e: + self._log(f"Connection error: {e}", "ERROR") + else: + self.root.after(100, check_result) + + check_result() + + def _disconnect(self): + """Disconnect from server.""" + self._log("Disconnecting from server") + self.client.disconnect() + + def _on_connect(self): + """Handle successful connection.""" + self.connected = True + self.connection_status.config(text="Connected", foreground="green") + self.connect_btn.config(text="Disconnect") + self._set_status("Connected to server") + + # Refresh data + self._refresh_models() + self._get_system_info() + + def _on_disconnect(self): + """Handle disconnection.""" + self.connected = False + self.connection_status.config(text="Disconnected", foreground="red") + self.connect_btn.config(text="Connect") + self._set_status("Disconnected from server") + + # Clear data + self.available_models = [] + self._update_model_list() + + def _on_error(self, error_message: str): + """Handle client errors.""" + self._log(error_message, "ERROR") + self._set_status(f"Error: {error_message}") + + def _refresh_models(self): + """Refresh the list of available models.""" + if not self.connected: + return + + self._log("Refreshing model list") + + def handle_response(response): + if 'error' in response: + self._log(f"Error listing models: {response['error']}", "ERROR") + return + + try: + result = response.get('result', []) + if result and isinstance(result, list) and len(result) > 0: + content = result[0].get('text', '[]') + self.available_models = json.loads(content) + else: + self.available_models = [] + + self.root.after(0, self._update_model_list) + self._log(f"Found {len(self.available_models)} models") + + except Exception as e: + self._log(f"Error parsing model list: {e}", "ERROR") + + future = asyncio.run_coroutine_threadsafe( + self.client.call_tool("list_models", {}, handle_response), + self.loop + ) + + def _update_model_list(self): + """Update the model list display.""" + # Clear existing items + for item in self.model_tree.get_children(): + self.model_tree.delete(item) + + # Add models + for model in self.available_models: + name = model.get('name', 'Unknown') + model_type = model.get('type', 'Unknown') + size_mb = model.get('size_mb', 0) + size_text = f"{size_mb:.1f} MB" if size_mb > 0 else "Unknown" + + self.model_tree.insert("", tk.END, values=(name, model_type, size_text, "Available")) + + def _load_selected_model(self): + """Load the selected model.""" + selection = self.model_tree.selection() + if not selection: + messagebox.showinfo("No Selection", "Please select a model to load") + return + + item = self.model_tree.item(selection[0]) + model_name = item['values'][0] + + # Find the model path + model_path = None + for model in self.available_models: + if model.get('name') == model_name: + model_path = model.get('path') + break + + if not model_path: + messagebox.showerror("Error", "Model path not found") + return + + self._log(f"Loading model: {model_name}") + + def handle_response(response): + if 'error' in response: + self._log(f"Error loading model: {response['error']}", "ERROR") + return + + try: + result = response.get('result', []) + if result and isinstance(result, list) and len(result) > 0: + content = result[0].get('text', '') + if 'Successfully loaded' in content: + self.current_model = model_name + self.root.after(0, lambda: self.current_model_label.config(text=f"Loaded: {model_name}")) + self._log(f"Model loaded successfully: {model_name}") + else: + self._log(f"Failed to load model: {content}", "ERROR") + + except Exception as e: + self._log(f"Error parsing load response: {e}", "ERROR") + + arguments = { + "model_path": model_path, + "n_ctx": self.ctx_var.get(), + "n_gpu_layers": self.gpu_var.get() + } + + future = asyncio.run_coroutine_threadsafe( + self.client.call_tool("load_model", arguments, handle_response), + self.loop + ) + + def _unload_model(self): + """Unload the current model.""" + if not self.current_model: + messagebox.showinfo("No Model", "No model is currently loaded") + return + + self._log("Unloading current model") + self.current_model = None + self.current_model_label.config(text="No model loaded") + + def _generate_text(self): + """Generate text using the current model.""" + if not self.connected: + messagebox.showerror("Error", "Not connected to server") + return + + if not self.current_model: + messagebox.showerror("Error", "No model loaded") + return + + prompt = self.prompt_text.get(1.0, tk.END).strip() + if not prompt: + messagebox.showwarning("Warning", "Please enter a prompt") + return + + self._log(f"Generating text for prompt: {prompt[:50]}...") + self._set_status("Generating...") + + def handle_response(response): + if 'error' in response: + self._log(f"Error generating text: {response['error']}", "ERROR") + return + + try: + result = response.get('result', []) + if result and isinstance(result, list) and len(result) > 0: + content_str = result[0].get('text', '{}') + content = json.loads(content_str) + + if 'error' in content: + self._log(f"Generation error: {content['error']}", "ERROR") + return + + generated_text = content.get('text', '') + + # Update output + self.root.after(0, lambda: self._update_output(prompt, generated_text)) + self._log("Text generation completed") + self.root.after(0, lambda: self._set_status("Generation completed")) + + except Exception as e: + self._log(f"Error parsing generation response: {e}", "ERROR") + + arguments = { + "prompt": prompt, + "max_tokens": self.max_tokens_var.get(), + "temperature": self.temperature_var.get() + } + + future = asyncio.run_coroutine_threadsafe( + self.client.call_tool("generate_text", arguments, handle_response), + self.loop + ) + + def _update_output(self, prompt: str, generated_text: str): + """Update the output text area.""" + self.output_text.config(state=tk.NORMAL) + + if self.chat_mode_var.get(): + # Chat mode - append to conversation + self.output_text.insert(tk.END, f"User: {prompt}\n\n") + self.output_text.insert(tk.END, f"Assistant: {generated_text}\n\n") + self.output_text.insert(tk.END, "-" * 50 + "\n\n") + else: + # Replace mode - show only current generation + self.output_text.delete(1.0, tk.END) + self.output_text.insert(tk.END, f"Prompt: {prompt}\n\n") + self.output_text.insert(tk.END, f"Response: {generated_text}") + + self.output_text.see(tk.END) + self.output_text.config(state=tk.DISABLED) + + def _clear_chat(self): + """Clear chat history.""" + self.output_text.config(state=tk.NORMAL) + self.output_text.delete(1.0, tk.END) + self.output_text.config(state=tk.DISABLED) + + def _get_system_info(self): + """Get system information from server.""" + if not self.connected: + return + + self._log("Getting system information") + + def handle_response(response): + if 'error' in response: + self._log(f"Error getting system info: {response['error']}", "ERROR") + return + + try: + result = response.get('result', []) + if result and isinstance(result, list) and len(result) > 0: + content_str = result[0].get('text', '{}') + self.system_info = json.loads(content_str) + + self.root.after(0, self._update_system_display) + self._log("System information updated") + + except Exception as e: + self._log(f"Error parsing system info: {e}", "ERROR") + + future = asyncio.run_coroutine_threadsafe( + self.client.call_tool("get_system_info", {}, handle_response), + self.loop + ) + + def _update_system_display(self): + """Update system information display.""" + self.system_text.config(state=tk.NORMAL) + self.system_text.delete(1.0, tk.END) + + # Format system info + info_text = "System Information\n" + info_text += "=" * 50 + "\n\n" + + info_text += f"Platform: {self.system_info.get('platform', 'Unknown')}\n" + info_text += f"Architecture: {self.system_info.get('architecture', 'Unknown')}\n\n" + + # Acceleration info + acceleration = self.system_info.get('acceleration', {}) + info_text += "Acceleration Support:\n" + info_text += f" CUDA Available: {acceleration.get('cuda_available', False)}\n" + if acceleration.get('cuda_available'): + info_text += f" CUDA Version: {acceleration.get('cuda_version', 'Unknown')}\n" + info_text += f" CUDA Devices: {acceleration.get('cuda_devices', 0)}\n" + + info_text += f" ROCm Available: {acceleration.get('rocm_available', False)}\n" + info_text += f" Metal Available: {acceleration.get('metal_available', False)}\n" + info_text += f" Intel GPU Available: {acceleration.get('intel_gpu_available', False)}\n" + info_text += f" Recommended GPU Layers: {acceleration.get('recommended_layers', 0)}\n\n" + + info_text += f"Current Model Acceleration: {self.system_info.get('current_model_acceleration', 'None')}\n" + + self.system_text.insert(tk.END, info_text) + self.system_text.config(state=tk.DISABLED) + + def _clear_log(self): + """Clear the log.""" + self.log_text.config(state=tk.NORMAL) + self.log_text.delete(1.0, tk.END) + self.log_text.config(state=tk.DISABLED) + + def _save_log(self): + """Save log to file.""" + file_path = filedialog.asksaveasfilename( + title="Save Log", + defaultextension=".txt", + filetypes=[("Text files", "*.txt"), ("All files", "*.*")] + ) + if file_path: + try: + with open(file_path, 'w') as f: + f.write(self.log_text.get(1.0, tk.END)) + self._log(f"Log saved to: {file_path}") + except Exception as e: + messagebox.showerror("Error", f"Failed to save log: {e}") + + def _show_about(self): + """Show about dialog.""" + about_text = """LLM_Train Remote Control v1.0.0 + +A remote control application for managing and +interacting with LLM_Train MCP servers. + +Features: +• Remote model loading and configuration +• Text generation and chat interface +• System information monitoring +• Connection management + +© 2024 LLM_Train Project""" + + messagebox.showinfo("About", about_text) + + def run(self): + """Run the application.""" + try: + self.root.protocol("WM_DELETE_WINDOW", self._on_closing) + self.root.mainloop() + except KeyboardInterrupt: + pass + finally: + self._cleanup() + + def _on_closing(self): + """Handle application closing.""" + if self.connected: + self.client.disconnect() + self._cleanup() + self.root.destroy() + + def _cleanup(self): + """Cleanup resources.""" + if hasattr(self, 'loop') and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + +def main(): + """Main entry point.""" + try: + app = RemoteControlGUI() + app.run() + except Exception as e: + print(f"Error starting application: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/settings_manager.py b/settings_manager.py new file mode 100644 index 0000000..65bcf9a --- /dev/null +++ b/settings_manager.py @@ -0,0 +1,867 @@ +import json +import os +from pathlib import Path +from typing import Dict, Any, Optional +import tkinter as tk +from tkinter import ttk, messagebox, filedialog + + +class SettingsManager: + """Manages application settings with JSON persistence.""" + + def __init__(self, settings_file: str = "settings.json"): + self.settings_file = settings_file + self.default_settings = { + "api": { + "huggingface_token": "", + "use_env_token": True, + "use_organization": False, + "organization": "" + }, + "paths": { + "models_directory": "./models", + "downloads_directory": "./downloads", + "last_model_path": "", + "last_lora_path": "" + }, + "model_settings": { + "default_n_ctx": 4096, + "default_n_gpu_layers": 0, + "default_max_tokens": 256, + "stream_by_default": True, + "temperature": 0.7, + "top_p": 0.9, + "repetition_penalty": 1.1, + "no_repeat_ngram_size": 0, + "min_p": 0.0, + "typical_p": 1.0 + }, + "search_preferences": { + "default_search_type": "Models", + "default_sort": "downloads", + "search_limit": 50, + "auto_filter_gguf": True + }, + "ui_preferences": { + "window_width": 1200, + "window_height": 700, + "theme": "default", + "show_tooltips": True + }, + "library": { + "root_folder": "", + "max_depth": 3, + "auto_scan_on_startup": False, + "watch_for_changes": False + }, + "download_settings": { + "max_concurrent_downloads": 3, + "max_download_speed": 0, + "min_download_speed": 0, + "retry_attempts": 3, + "timeout_seconds": 30 + } + } + self.settings = self.load_settings() + + def load_settings(self) -> Dict[str, Any]: + """Load settings from file or create default.""" + if os.path.exists(self.settings_file): + try: + with open(self.settings_file, 'r') as f: + loaded = json.load(f) + # Merge with defaults to handle new keys + return self._merge_settings(self.default_settings, loaded) + except Exception as e: + print(f"Error loading settings: {e}") + return self.default_settings.copy() + return self.default_settings.copy() + + def _merge_settings(self, defaults: Dict, loaded: Dict) -> Dict: + """Merge loaded settings with defaults, preserving user values.""" + result = defaults.copy() + for key, value in loaded.items(): + if key in result: + if isinstance(result[key], dict) and isinstance(value, dict): + result[key] = self._merge_settings(result[key], value) + else: + result[key] = value + else: + result[key] = value + return result + + def save_settings(self): + """Save current settings to file.""" + try: + with open(self.settings_file, 'w') as f: + json.dump(self.settings, f, indent=2) + return True + except Exception as e: + print(f"Error saving settings: {e}") + return False + + def get(self, path: str, default: Any = None) -> Any: + """Get a setting value using dot notation (e.g., 'api.huggingface_token').""" + keys = path.split('.') + value = self.settings + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + return default + return value + + def set(self, path: str, value: Any): + """Set a setting value using dot notation.""" + keys = path.split('.') + target = self.settings + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value + + def reset_to_defaults(self): + """Reset all settings to defaults.""" + self.settings = self.default_settings.copy() + self.save_settings() + + +class SettingsDialog: + """Settings dialog window for the application.""" + + def __init__(self, parent: tk.Tk, settings_manager: SettingsManager): + self.parent = parent + self.settings = settings_manager + self.dialog = tk.Toplevel(parent) + self.dialog.title("Settings") + self.dialog.geometry("700x500") + self.dialog.resizable(True, True) + + # Make dialog modal + self.dialog.transient(parent) + self.dialog.grab_set() + + # Variables for settings + self.vars = {} + self._create_variables() + + # Build UI + self._build_ui() + + # Load current settings into UI + self._load_current_settings() + + # Center the dialog + self._center_window() + + def _create_variables(self): + """Create tkinter variables for settings.""" + self.vars = { + # API Settings + 'hf_token': tk.StringVar(), + 'use_env_token': tk.BooleanVar(), + 'use_organization': tk.BooleanVar(), + 'organization': tk.StringVar(), + + # Path Settings + 'models_dir': tk.StringVar(), + 'downloads_dir': tk.StringVar(), + + # Model Settings + 'default_n_ctx': tk.IntVar(), + 'default_n_gpu': tk.IntVar(), + 'default_max_tokens': tk.IntVar(), + 'stream_default': tk.BooleanVar(), + 'temperature': tk.DoubleVar(), + 'top_p': tk.DoubleVar(), + 'repetition_penalty': tk.DoubleVar(), + 'no_repeat_ngram_size': tk.IntVar(), + 'min_p': tk.DoubleVar(), + 'typical_p': tk.DoubleVar(), + + # Search Settings + 'search_type': tk.StringVar(), + 'default_sort': tk.StringVar(), + 'search_limit': tk.IntVar(), + 'auto_filter_gguf': tk.BooleanVar(), + + # UI Settings + 'show_tooltips': tk.BooleanVar(), + 'theme': tk.StringVar(), + + # Library Settings + 'library_root': tk.StringVar(), + 'library_depth': tk.IntVar(), + 'auto_scan': tk.BooleanVar(), + 'watch_changes': tk.BooleanVar(), + + # Download Settings + 'max_downloads': tk.IntVar(), + 'max_speed': tk.IntVar(), + 'min_speed': tk.IntVar(), + 'retry_attempts': tk.IntVar(), + 'timeout_seconds': tk.IntVar() + } + + def _build_ui(self): + """Build the settings dialog UI.""" + # Create notebook for tabs + notebook = ttk.Notebook(self.dialog) + notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # API Settings Tab + api_frame = ttk.Frame(notebook) + notebook.add(api_frame, text="API") + self._build_api_tab(api_frame) + + # Paths Tab + paths_frame = ttk.Frame(notebook) + notebook.add(paths_frame, text="Paths") + self._build_paths_tab(paths_frame) + + # Model Settings Tab + model_frame = ttk.Frame(notebook) + notebook.add(model_frame, text="Model Defaults") + self._build_model_tab(model_frame) + + # Search Settings Tab + search_frame = ttk.Frame(notebook) + notebook.add(search_frame, text="Search") + self._build_search_tab(search_frame) + + # UI Preferences Tab + ui_frame = ttk.Frame(notebook) + notebook.add(ui_frame, text="Interface") + self._build_ui_tab(ui_frame) + + # Library Settings Tab + library_frame = ttk.Frame(notebook) + notebook.add(library_frame, text="Library") + self._build_library_tab(library_frame) + + # Download Settings Tab + download_frame = ttk.Frame(notebook) + notebook.add(download_frame, text="Downloads") + self._build_download_tab(download_frame) + + # Buttons at bottom + button_frame = ttk.Frame(self.dialog) + button_frame.pack(fill=tk.X, padx=10, pady=(0, 10)) + + ttk.Button(button_frame, text="Save", command=self._save_settings).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="Cancel", command=self.dialog.destroy).pack(side=tk.RIGHT) + ttk.Button(button_frame, text="Reset to Defaults", command=self._reset_defaults).pack(side=tk.LEFT) + + def _build_api_tab(self, parent: ttk.Frame): + """Build API settings tab.""" + # Main container with scrollbar + canvas = tk.Canvas(parent) + scrollbar = ttk.Scrollbar(parent, orient="vertical", command=canvas.yview) + scrollable_frame = ttk.Frame(canvas) + + scrollable_frame.bind( + "", + lambda e: canvas.configure(scrollregion=canvas.bbox("all")) + ) + + canvas.create_window((0, 0), window=scrollable_frame, anchor="nw") + canvas.configure(yscrollcommand=scrollbar.set) + + # API Token Frame + token_frame = ttk.LabelFrame(scrollable_frame, text="API Token Management", padding=10) + token_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Use environment token checkbox + ttk.Checkbutton(token_frame, text="Use token from HUGGINGFACE.env file", + variable=self.vars['use_env_token'], + command=self._toggle_token_entry).grid(row=0, column=0, columnspan=3, sticky=tk.W, pady=5) + + # API Token entry + ttk.Label(token_frame, text="API Token:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.token_entry = ttk.Entry(token_frame, textvariable=self.vars['hf_token'], width=40, show="*") + self.token_entry.grid(row=1, column=1, sticky=tk.W, pady=5) + + # Token action buttons + token_buttons = ttk.Frame(token_frame) + token_buttons.grid(row=1, column=2, padx=5) + + self.show_token_btn = ttk.Button(token_buttons, text="View", width=8, + command=self._toggle_token_visibility) + self.show_token_btn.pack(side=tk.LEFT, padx=2) + + ttk.Button(token_buttons, text="Change", width=8, + command=self._change_token).pack(side=tk.LEFT, padx=2) + + ttk.Button(token_buttons, text="Test", width=8, + command=self._test_api_key).pack(side=tk.LEFT, padx=2) + + # Organization Frame + org_frame = ttk.LabelFrame(scrollable_frame, text="Organization Settings", padding=10) + org_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Use organization checkbox + ttk.Checkbutton(org_frame, text="Use as HuggingFace organization member", + variable=self.vars['use_organization'], + command=self._toggle_organization).grid(row=0, column=0, columnspan=3, sticky=tk.W, pady=5) + + # Organization dropdown + ttk.Label(org_frame, text="Organization:").grid(row=1, column=0, sticky=tk.W, pady=5) + self.org_combo = ttk.Combobox(org_frame, textvariable=self.vars['organization'], + state="disabled", width=30) + self.org_combo.grid(row=1, column=1, sticky=tk.W, pady=5) + + ttk.Button(org_frame, text="Fetch Organizations", + command=self._fetch_organizations).grid(row=1, column=2, padx=5) + + # Organizations list (for display) + self.org_listbox = tk.Listbox(org_frame, height=5, width=50) + self.org_listbox.grid(row=2, column=0, columnspan=3, pady=10) + self.org_listbox.bind('<>', self._on_org_select) + + # Info labels + info_frame = ttk.Frame(scrollable_frame) + info_frame.pack(fill=tk.X, padx=10, pady=10) + + info_text = ("• API tokens can be obtained from: https://huggingface.co/settings/tokens\n" + "• Organizations allow you to access private repos and team resources\n" + "• Test your API key to verify it's working correctly") + ttk.Label(info_frame, text=info_text, foreground="gray").pack(anchor=tk.W) + + # API Status label + self.api_status_label = ttk.Label(info_frame, text="", foreground="green") + self.api_status_label.pack(anchor=tk.W, pady=5) + + canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + def _build_paths_tab(self, parent: ttk.Frame): + """Build paths settings tab.""" + frame = ttk.LabelFrame(parent, text="Default Directories", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Models directory + ttk.Label(frame, text="Models Directory:").grid(row=0, column=0, sticky=tk.W, pady=5) + ttk.Entry(frame, textvariable=self.vars['models_dir'], width=40).grid(row=0, column=1, pady=5) + ttk.Button(frame, text="Browse", + command=lambda: self._browse_directory('models_dir')).grid(row=0, column=2, padx=5) + + # Downloads directory + ttk.Label(frame, text="Downloads Directory:").grid(row=1, column=0, sticky=tk.W, pady=5) + ttk.Entry(frame, textvariable=self.vars['downloads_dir'], width=40).grid(row=1, column=1, pady=5) + ttk.Button(frame, text="Browse", + command=lambda: self._browse_directory('downloads_dir')).grid(row=1, column=2, padx=5) + + # Create directories checkbox + self.create_dirs_var = tk.BooleanVar(value=True) + ttk.Checkbutton(frame, text="Create directories if they don't exist", + variable=self.create_dirs_var).grid(row=2, column=0, columnspan=3, pady=10) + + def _build_model_tab(self, parent: ttk.Frame): + """Build model defaults tab.""" + frame = ttk.LabelFrame(parent, text="Default Model Settings", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Context size + ttk.Label(frame, text="Default Context Size (n_ctx):").grid(row=0, column=0, sticky=tk.W, pady=5) + ctx_spin = tk.Spinbox(frame, from_=512, to=32768, increment=512, + textvariable=self.vars['default_n_ctx'], width=15) + ctx_spin.grid(row=0, column=1, sticky=tk.W, pady=5) + + # GPU layers + ttk.Label(frame, text="Default GPU Layers:").grid(row=1, column=0, sticky=tk.W, pady=5) + gpu_spin = tk.Spinbox(frame, from_=0, to=100, increment=1, + textvariable=self.vars['default_n_gpu'], width=15) + gpu_spin.grid(row=1, column=1, sticky=tk.W, pady=5) + + # Max tokens + ttk.Label(frame, text="Default Max Tokens:").grid(row=2, column=0, sticky=tk.W, pady=5) + tokens_spin = tk.Spinbox(frame, from_=16, to=8192, increment=16, + textvariable=self.vars['default_max_tokens'], width=15) + tokens_spin.grid(row=2, column=1, sticky=tk.W, pady=5) + + # Stream by default + ttk.Checkbutton(frame, text="Stream output by default", + variable=self.vars['stream_default']).grid(row=3, column=0, columnspan=2, pady=10) + + # Temperature + ttk.Label(frame, text="Temperature:").grid(row=4, column=0, sticky=tk.W, pady=5) + temp_spin = tk.Spinbox(frame, from_=0.0, to=2.0, increment=0.1, + textvariable=self.vars['temperature'], width=15, format="%.1f") + temp_spin.grid(row=4, column=1, sticky=tk.W, pady=5) + + # Top P + ttk.Label(frame, text="Top P:").grid(row=5, column=0, sticky=tk.W, pady=5) + top_p_spin = tk.Spinbox(frame, from_=0.0, to=1.0, increment=0.1, + textvariable=self.vars['top_p'], width=15, format="%.1f") + top_p_spin.grid(row=5, column=1, sticky=tk.W, pady=5) + + # Repetition Penalty + ttk.Label(frame, text="Repetition Penalty:").grid(row=6, column=0, sticky=tk.W, pady=5) + rep_pen_spin = tk.Spinbox(frame, from_=0.5, to=2.0, increment=0.1, + textvariable=self.vars['repetition_penalty'], width=15, format="%.1f") + rep_pen_spin.grid(row=6, column=1, sticky=tk.W, pady=5) + + # No Repeat N-gram Size + ttk.Label(frame, text="No Repeat N-gram Size:").grid(row=7, column=0, sticky=tk.W, pady=5) + ngram_spin = tk.Spinbox(frame, from_=0, to=10, increment=1, + textvariable=self.vars['no_repeat_ngram_size'], width=15) + ngram_spin.grid(row=7, column=1, sticky=tk.W, pady=5) + + # Min P + ttk.Label(frame, text="Min P:").grid(row=8, column=0, sticky=tk.W, pady=5) + min_p_spin = tk.Spinbox(frame, from_=0.0, to=1.0, increment=0.01, + textvariable=self.vars['min_p'], width=15, format="%.2f") + min_p_spin.grid(row=8, column=1, sticky=tk.W, pady=5) + + # Typical P + ttk.Label(frame, text="Typical P:").grid(row=9, column=0, sticky=tk.W, pady=5) + typical_p_spin = tk.Spinbox(frame, from_=0.0, to=1.0, increment=0.1, + textvariable=self.vars['typical_p'], width=15, format="%.1f") + typical_p_spin.grid(row=9, column=1, sticky=tk.W, pady=5) + + def _build_search_tab(self, parent: ttk.Frame): + """Build search settings tab.""" + frame = ttk.LabelFrame(parent, text="Search Preferences", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Default search type + ttk.Label(frame, text="Default Search Type:").grid(row=0, column=0, sticky=tk.W, pady=5) + ttk.Combobox(frame, textvariable=self.vars['search_type'], + values=["Models", "Datasets"], state="readonly", width=20).grid(row=0, column=1, sticky=tk.W, pady=5) + + # Default sort + ttk.Label(frame, text="Default Sort By:").grid(row=1, column=0, sticky=tk.W, pady=5) + ttk.Combobox(frame, textvariable=self.vars['default_sort'], + values=["downloads", "likes", "lastModified"], state="readonly", width=20).grid(row=1, column=1, sticky=tk.W, pady=5) + + # Search limit + ttk.Label(frame, text="Results Limit:").grid(row=2, column=0, sticky=tk.W, pady=5) + limit_spin = tk.Spinbox(frame, from_=10, to=200, increment=10, + textvariable=self.vars['search_limit'], width=20) + limit_spin.grid(row=2, column=1, sticky=tk.W, pady=5) + + # Auto filter GGUF + ttk.Checkbutton(frame, text="Automatically filter for GGUF files when downloading models", + variable=self.vars['auto_filter_gguf']).grid(row=3, column=0, columnspan=2, pady=10) + + def _build_ui_tab(self, parent: ttk.Frame): + """Build UI preferences tab.""" + frame = ttk.LabelFrame(parent, text="Interface Settings", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Theme selection + ttk.Label(frame, text="Theme:").grid(row=0, column=0, sticky=tk.W, pady=5) + ttk.Combobox(frame, textvariable=self.vars['theme'], + values=["default", "dark", "light"], state="readonly", width=20).grid(row=0, column=1, sticky=tk.W, pady=5) + + # Show tooltips + ttk.Checkbutton(frame, text="Show tooltips", + variable=self.vars['show_tooltips']).grid(row=1, column=0, columnspan=2, pady=10) + + # Note about themes + ttk.Label(frame, text="Note: Theme changes will take effect after restart", + foreground="gray").grid(row=2, column=0, columnspan=2, pady=5) + + def _build_library_tab(self, parent: ttk.Frame): + """Build library settings tab.""" + frame = ttk.LabelFrame(parent, text="Model Library Settings", padding=10) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Library root folder + ttk.Label(frame, text="Library Root Folder:").grid(row=0, column=0, sticky=tk.W, pady=5) + root_frame = ttk.Frame(frame) + root_frame.grid(row=0, column=1, columnspan=2, sticky=tk.W, pady=5) + + ttk.Entry(root_frame, textvariable=self.vars['library_root'], width=40).pack(side=tk.LEFT) + ttk.Button(root_frame, text="Browse", + command=lambda: self._browse_directory('library_root')).pack(side=tk.LEFT, padx=5) + + # Scan depth + ttk.Label(frame, text="Maximum Scan Depth:").grid(row=1, column=0, sticky=tk.W, pady=5) + depth_frame = ttk.Frame(frame) + depth_frame.grid(row=1, column=1, columnspan=2, sticky=tk.W, pady=5) + + depth_scale = ttk.Scale(depth_frame, from_=1, to=10, variable=self.vars['library_depth'], + orient=tk.HORIZONTAL, length=200) + depth_scale.pack(side=tk.LEFT) + + depth_label = ttk.Label(depth_frame, textvariable=self.vars['library_depth']) + depth_label.pack(side=tk.LEFT, padx=10) + ttk.Label(depth_frame, text="levels").pack(side=tk.LEFT) + + # Auto scan options + ttk.Checkbutton(frame, text="Auto-scan library on startup", + variable=self.vars['auto_scan']).grid(row=2, column=0, columnspan=3, sticky=tk.W, pady=5) + + ttk.Checkbutton(frame, text="Watch for file system changes (experimental)", + variable=self.vars['watch_changes']).grid(row=3, column=0, columnspan=3, sticky=tk.W, pady=5) + + # Info text + info_text = ("The library scanner searches for model files in the specified folder.\n" + "Scan depth controls how many subdirectory levels to search.\n" + "Supported formats: .gguf, .bin, .safetensors, .pt, .pth, .onnx") + ttk.Label(frame, text=info_text, foreground="gray").grid(row=4, column=0, columnspan=3, pady=10) + + def _build_download_tab(self, parent: ttk.Frame): + """Build download settings tab.""" + # Download limits frame + limits_frame = ttk.LabelFrame(parent, text="Download Limits", padding=10) + limits_frame.pack(fill=tk.X, padx=10, pady=10) + + # Max concurrent downloads + ttk.Label(limits_frame, text="Maximum Concurrent Downloads:").grid(row=0, column=0, sticky=tk.W, pady=5) + concurrent_spin = tk.Spinbox(limits_frame, from_=1, to=10, increment=1, + textvariable=self.vars['max_downloads'], width=15) + concurrent_spin.grid(row=0, column=1, sticky=tk.W, pady=5) + ttk.Label(limits_frame, text="(1-10 downloads)").grid(row=0, column=2, sticky=tk.W, padx=5) + + # Speed limits frame + speed_frame = ttk.LabelFrame(parent, text="Speed Limits (KB/s)", padding=10) + speed_frame.pack(fill=tk.X, padx=10, pady=10) + + # Max download speed + ttk.Label(speed_frame, text="Maximum Download Speed:").grid(row=0, column=0, sticky=tk.W, pady=5) + max_speed_spin = tk.Spinbox(speed_frame, from_=0, to=100000, increment=100, + textvariable=self.vars['max_speed'], width=15) + max_speed_spin.grid(row=0, column=1, sticky=tk.W, pady=5) + ttk.Label(speed_frame, text="(0 = unlimited)").grid(row=0, column=2, sticky=tk.W, padx=5) + + # Min download speed + ttk.Label(speed_frame, text="Minimum Download Speed:").grid(row=1, column=0, sticky=tk.W, pady=5) + min_speed_spin = tk.Spinbox(speed_frame, from_=0, to=10000, increment=10, + textvariable=self.vars['min_speed'], width=15) + min_speed_spin.grid(row=1, column=1, sticky=tk.W, pady=5) + ttk.Label(speed_frame, text="(0 = no minimum)").grid(row=1, column=2, sticky=tk.W, padx=5) + + # Connection settings frame + conn_frame = ttk.LabelFrame(parent, text="Connection Settings", padding=10) + conn_frame.pack(fill=tk.X, padx=10, pady=10) + + # Retry attempts + ttk.Label(conn_frame, text="Retry Attempts:").grid(row=0, column=0, sticky=tk.W, pady=5) + retry_spin = tk.Spinbox(conn_frame, from_=0, to=10, increment=1, + textvariable=self.vars['retry_attempts'], width=15) + retry_spin.grid(row=0, column=1, sticky=tk.W, pady=5) + ttk.Label(conn_frame, text="(number of retries on failure)").grid(row=0, column=2, sticky=tk.W, padx=5) + + # Timeout + ttk.Label(conn_frame, text="Connection Timeout:").grid(row=1, column=0, sticky=tk.W, pady=5) + timeout_spin = tk.Spinbox(conn_frame, from_=5, to=300, increment=5, + textvariable=self.vars['timeout_seconds'], width=15) + timeout_spin.grid(row=1, column=1, sticky=tk.W, pady=5) + ttk.Label(conn_frame, text="(seconds)").grid(row=1, column=2, sticky=tk.W, padx=5) + + # Info text + info_text = ("• Speed limits help manage bandwidth usage\n" + "• Concurrent downloads should be balanced with your internet connection\n" + "• Higher timeout values help with slow connections") + ttk.Label(parent, text=info_text, foreground="gray").pack(anchor=tk.W, padx=10, pady=10) + + def _toggle_token_entry(self): + """Enable/disable token entry based on checkbox.""" + if self.vars['use_env_token'].get(): + self.token_entry.config(state="disabled") + else: + self.token_entry.config(state="normal") + + def _toggle_token_visibility(self): + """Toggle token visibility.""" + if self.token_entry['show'] == "*": + self.token_entry.config(show="") + self.show_token_btn.config(text="Hide") + else: + self.token_entry.config(show="*") + self.show_token_btn.config(text="View") + + def _change_token(self): + """Open dialog to change API token.""" + import tkinter.simpledialog as simpledialog + + new_token = simpledialog.askstring( + "Change API Token", + "Enter new HuggingFace API token:", + parent=self.dialog, + show='*' + ) + + if new_token: + self.vars['hf_token'].set(new_token) + self.vars['use_env_token'].set(False) + self._toggle_token_entry() + self.api_status_label.config(text="Token updated. Click Test to verify.", foreground="blue") + + def _test_api_key(self): + """Test the API key.""" + import requests + + # Get the token to test + if self.vars['use_env_token'].get(): + import os + from dotenv import load_dotenv + load_dotenv("HUGGINGFACE.env") + token = os.getenv("HF_API_KEY") + else: + token = self.vars['hf_token'].get() + + if not token: + self.api_status_label.config(text="No API token configured", foreground="red") + return + + try: + # Test API by fetching user info + # Ensure token is properly stripped of whitespace and newlines + clean_token = token.strip().replace('\n', '').replace('\r', '') + headers = {"Authorization": f"Bearer {clean_token}"} + response = requests.get("https://huggingface.co/api/whoami", headers=headers) + + if response.status_code == 200: + user_data = response.json() + username = user_data.get('name', 'Unknown') + self.api_status_label.config( + text=f"✓ API key valid. Logged in as: {username}", + foreground="green" + ) + + # Update organizations if found + orgs = user_data.get('orgs', []) + if orgs: + org_names = [org.get('name', '') for org in orgs] + self.org_listbox.delete(0, tk.END) + for org in org_names: + self.org_listbox.insert(tk.END, org) + self.org_combo['values'] = org_names + + elif response.status_code == 401: + self.api_status_label.config(text="✗ Invalid API token", foreground="red") + else: + self.api_status_label.config( + text=f"✗ API test failed: {response.status_code}", + foreground="red" + ) + + except Exception as e: + self.api_status_label.config(text=f"✗ Connection error: {str(e)[:50]}", foreground="red") + + def _toggle_organization(self): + """Enable/disable organization controls.""" + if self.vars['use_organization'].get(): + self.org_combo.config(state="readonly") + if not self.org_combo['values']: + self._fetch_organizations() + else: + self.org_combo.config(state="disabled") + + def _fetch_organizations(self): + """Fetch organizations for the current API key.""" + import requests + + # Get the token + if self.vars['use_env_token'].get(): + import os + from dotenv import load_dotenv + load_dotenv("HUGGINGFACE.env") + token = os.getenv("HF_API_KEY") + else: + token = self.vars['hf_token'].get() + + if not token: + messagebox.showwarning("No Token", "Please configure an API token first") + return + + try: + # Ensure token is properly stripped of whitespace and newlines + clean_token = token.strip().replace('\n', '').replace('\r', '') + headers = {"Authorization": f"Bearer {clean_token}"} + response = requests.get("https://huggingface.co/api/whoami", headers=headers) + + if response.status_code == 200: + user_data = response.json() + orgs = user_data.get('orgs', []) + + if orgs: + org_names = [org.get('name', '') for org in orgs] + self.org_listbox.delete(0, tk.END) + for org in org_names: + self.org_listbox.insert(tk.END, org) + self.org_combo['values'] = org_names + + if org_names: + self.org_combo.set(org_names[0]) + self.api_status_label.config( + text=f"Found {len(org_names)} organization(s)", + foreground="green" + ) + else: + self.api_status_label.config( + text="No organizations found for this account", + foreground="blue" + ) + else: + self.api_status_label.config( + text=f"Failed to fetch organizations: {response.status_code}", + foreground="red" + ) + + except Exception as e: + messagebox.showerror("Error", f"Failed to fetch organizations: {str(e)}") + + def _on_org_select(self, event): + """Handle organization selection from listbox.""" + selection = self.org_listbox.curselection() + if selection: + org_name = self.org_listbox.get(selection[0]) + self.vars['organization'].set(org_name) + + def _browse_directory(self, var_name: str): + """Browse for directory.""" + directory = filedialog.askdirectory( + parent=self.dialog, + initialdir=self.vars[var_name].get() or "." + ) + if directory: + self.vars[var_name].set(directory) + + def _load_current_settings(self): + """Load current settings into UI variables.""" + self.vars['hf_token'].set(self.settings.get('api.huggingface_token', '')) + self.vars['use_env_token'].set(self.settings.get('api.use_env_token', True)) + self.vars['use_organization'].set(self.settings.get('api.use_organization', False)) + self.vars['organization'].set(self.settings.get('api.organization', '')) + + self.vars['models_dir'].set(self.settings.get('paths.models_directory', './models')) + self.vars['downloads_dir'].set(self.settings.get('paths.downloads_directory', './downloads')) + + self.vars['default_n_ctx'].set(self.settings.get('model_settings.default_n_ctx', 4096)) + self.vars['default_n_gpu'].set(self.settings.get('model_settings.default_n_gpu_layers', 0)) + self.vars['default_max_tokens'].set(self.settings.get('model_settings.default_max_tokens', 256)) + self.vars['stream_default'].set(self.settings.get('model_settings.stream_by_default', True)) + self.vars['temperature'].set(self.settings.get('model_settings.temperature', 0.7)) + self.vars['top_p'].set(self.settings.get('model_settings.top_p', 0.9)) + self.vars['repetition_penalty'].set(self.settings.get('model_settings.repetition_penalty', 1.1)) + self.vars['no_repeat_ngram_size'].set(self.settings.get('model_settings.no_repeat_ngram_size', 0)) + self.vars['min_p'].set(self.settings.get('model_settings.min_p', 0.0)) + self.vars['typical_p'].set(self.settings.get('model_settings.typical_p', 1.0)) + + self.vars['search_type'].set(self.settings.get('search_preferences.default_search_type', 'Models')) + self.vars['default_sort'].set(self.settings.get('search_preferences.default_sort', 'downloads')) + self.vars['search_limit'].set(self.settings.get('search_preferences.search_limit', 50)) + self.vars['auto_filter_gguf'].set(self.settings.get('search_preferences.auto_filter_gguf', True)) + + self.vars['show_tooltips'].set(self.settings.get('ui_preferences.show_tooltips', True)) + self.vars['theme'].set(self.settings.get('ui_preferences.theme', 'default')) + + # Library settings + self.vars['library_root'].set(self.settings.get('library.root_folder', '')) + self.vars['library_depth'].set(self.settings.get('library.max_depth', 3)) + self.vars['auto_scan'].set(self.settings.get('library.auto_scan_on_startup', False)) + self.vars['watch_changes'].set(self.settings.get('library.watch_for_changes', False)) + + # Download settings + self.vars['max_downloads'].set(self.settings.get('download_settings.max_concurrent_downloads', 3)) + self.vars['max_speed'].set(self.settings.get('download_settings.max_download_speed', 0)) + self.vars['min_speed'].set(self.settings.get('download_settings.min_download_speed', 0)) + self.vars['retry_attempts'].set(self.settings.get('download_settings.retry_attempts', 3)) + self.vars['timeout_seconds'].set(self.settings.get('download_settings.timeout_seconds', 30)) + + # Update token entry state + self._toggle_token_entry() + + def _save_settings(self): + """Save settings from UI to settings manager.""" + # API settings (strip whitespace from strings) + self.settings.set('api.huggingface_token', self.vars['hf_token'].get().strip()) + self.settings.set('api.use_env_token', self.vars['use_env_token'].get()) + self.settings.set('api.use_organization', self.vars['use_organization'].get()) + self.settings.set('api.organization', self.vars['organization'].get().strip()) + + # Path settings + models_dir = self.vars['models_dir'].get() + downloads_dir = self.vars['downloads_dir'].get() + + # Create directories if requested + if self.create_dirs_var.get(): + for directory in [models_dir, downloads_dir]: + if directory and not os.path.exists(directory): + try: + os.makedirs(directory, exist_ok=True) + except Exception as e: + messagebox.showerror("Error", f"Failed to create directory {directory}: {e}") + + self.settings.set('paths.models_directory', models_dir) + self.settings.set('paths.downloads_directory', downloads_dir) + + # Model settings + self.settings.set('model_settings.default_n_ctx', self.vars['default_n_ctx'].get()) + self.settings.set('model_settings.default_n_gpu_layers', self.vars['default_n_gpu'].get()) + self.settings.set('model_settings.default_max_tokens', self.vars['default_max_tokens'].get()) + self.settings.set('model_settings.stream_by_default', self.vars['stream_default'].get()) + self.settings.set('model_settings.temperature', self.vars['temperature'].get()) + self.settings.set('model_settings.top_p', self.vars['top_p'].get()) + self.settings.set('model_settings.repetition_penalty', self.vars['repetition_penalty'].get()) + self.settings.set('model_settings.no_repeat_ngram_size', self.vars['no_repeat_ngram_size'].get()) + self.settings.set('model_settings.min_p', self.vars['min_p'].get()) + self.settings.set('model_settings.typical_p', self.vars['typical_p'].get()) + + # Search settings + self.settings.set('search_preferences.default_search_type', self.vars['search_type'].get()) + self.settings.set('search_preferences.default_sort', self.vars['default_sort'].get()) + self.settings.set('search_preferences.search_limit', self.vars['search_limit'].get()) + self.settings.set('search_preferences.auto_filter_gguf', self.vars['auto_filter_gguf'].get()) + + # UI settings + self.settings.set('ui_preferences.show_tooltips', self.vars['show_tooltips'].get()) + self.settings.set('ui_preferences.theme', self.vars['theme'].get()) + + # Library settings + self.settings.set('library.root_folder', self.vars['library_root'].get().strip()) + self.settings.set('library.max_depth', self.vars['library_depth'].get()) + self.settings.set('library.auto_scan_on_startup', self.vars['auto_scan'].get()) + self.settings.set('library.watch_for_changes', self.vars['watch_changes'].get()) + + # Download settings + self.settings.set('download_settings.max_concurrent_downloads', self.vars['max_downloads'].get()) + self.settings.set('download_settings.max_download_speed', self.vars['max_speed'].get()) + self.settings.set('download_settings.min_download_speed', self.vars['min_speed'].get()) + self.settings.set('download_settings.retry_attempts', self.vars['retry_attempts'].get()) + self.settings.set('download_settings.timeout_seconds', self.vars['timeout_seconds'].get()) + + # Save to file + if self.settings.save_settings(): + messagebox.showinfo("Settings", "Settings saved successfully!") + self.dialog.destroy() + else: + messagebox.showerror("Error", "Failed to save settings") + + def _reset_defaults(self): + """Reset settings to defaults.""" + if messagebox.askyesno("Reset Settings", "Are you sure you want to reset all settings to defaults?"): + self.settings.reset_to_defaults() + self._load_current_settings() + messagebox.showinfo("Settings", "Settings reset to defaults") + + def _center_window(self): + """Center the dialog on the parent window.""" + self.dialog.update_idletasks() + + # Get parent position + parent_x = self.parent.winfo_x() + parent_y = self.parent.winfo_y() + parent_width = self.parent.winfo_width() + parent_height = self.parent.winfo_height() + + # Get dialog size + dialog_width = self.dialog.winfo_width() + dialog_height = self.dialog.winfo_height() + + # Calculate position + x = parent_x + (parent_width - dialog_width) // 2 + y = parent_y + (parent_height - dialog_height) // 2 + + self.dialog.geometry(f"+{x}+{y}") + + +def open_settings_dialog(parent: tk.Tk, settings_manager: SettingsManager): + """Convenience function to open the settings dialog.""" + dialog = SettingsDialog(parent, settings_manager) + return dialog \ No newline at end of file diff --git a/simple_agent_mode.py b/simple_agent_mode.py new file mode 100644 index 0000000..6483973 --- /dev/null +++ b/simple_agent_mode.py @@ -0,0 +1,584 @@ +#!/usr/bin/env python3 +""" +Simple Agent Mode - Direct system command execution through AI +""" + +import os +import sys +import subprocess +import shutil +import platform +import ctypes +import json +import tempfile +import asyncio +import psutil +import time +from typing import Optional, Dict, Any, List + +class SimpleAgentExecutor: + """Simple agent that can execute system commands based on AI responses""" + + def __init__(self, log_callback=None): + self.tools = self._register_tools() + self.log_callback = log_callback or print + self.active_processes = {} # Store PIDs of opened processes + + def _register_tools(self) -> Dict[str, callable]: + """Register available system tools""" + tools = {} + + # PowerShell execution + def powershell(command: str) -> str: + """Execute PowerShell command""" + try: + exe = shutil.which("pwsh") or shutil.which("powershell") + if not exe: + return "Error: PowerShell not found" + result = subprocess.run( + [exe, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-Command", command], + capture_output=True, + text=True, + timeout=30 + ) + output = result.stdout + if result.stderr: + output += f"\n[stderr]: {result.stderr}" + return output.strip() or "Command executed successfully" + except Exception as e: + return f"Error: {e}" + + # Bash execution + def bash(command: str) -> str: + """Execute Bash command""" + try: + exe = shutil.which("bash") + if not exe: + return "Error: bash not found" + result = subprocess.run( + [exe, "-c", command], + capture_output=True, + text=True, + timeout=30 + ) + output = result.stdout + if result.stderr: + output += f"\n[stderr]: {result.stderr}" + return output.strip() or "Command executed successfully" + except Exception as e: + return f"Error: {e}" + + # Generic shell command + def shell(command: str) -> str: + """Execute shell command""" + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=30 + ) + output = result.stdout + if result.stderr: + output += f"\n[stderr]: {result.stderr}" + return output.strip() or "Command executed successfully" + except Exception as e: + return f"Error: {e}" + + # Open applications with PID tracking + def open_app(app_name: str) -> str: + """Open an application and track its PID""" + try: + self.log_callback(f"Opening application: {app_name}") + + if platform.system() == "Windows": + # Try common Windows apps + apps = { + "notepad": "notepad.exe", + "word": "winword.exe", + "excel": "excel.exe", + "powerpoint": "powerpnt.exe", + "calculator": "calc.exe", + "paint": "mspaint.exe", + "cmd": "cmd.exe", + "powershell": "powershell.exe", + "explorer": "explorer.exe", + } + + app_path = apps.get(app_name.lower(), app_name) + process = subprocess.Popen([app_path], shell=True) + + # Wait a bit for process to start + time.sleep(1) + + # Find the actual PID of the opened window + try: + for proc in psutil.process_iter(['pid', 'name']): + if proc.info['name'] and app_path.lower() in proc.info['name'].lower(): + pid = proc.info['pid'] + self.active_processes[app_name.lower()] = pid + self.log_callback(f"Tracked process {app_name} with PID: {pid}") + return f"Opened {app_name} (PID: {pid})" + except: + pass + + return f"Opened {app_name}" + + elif platform.system() == "Darwin": # macOS + process = subprocess.Popen(["open", "-a", app_name]) + return f"Opened {app_name}" + + else: # Linux + process = subprocess.Popen([app_name], shell=True) + return f"Opened {app_name}" + + except Exception as e: + self.log_callback(f"Error opening {app_name}: {e}") + return f"Error opening {app_name}: {e}" + + # File operations + def read_file(filepath: str) -> str: + """Read a file""" + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + return f"File content:\n{content}" + except Exception as e: + return f"Error reading file: {e}" + + def write_file(filepath: str, content: str) -> str: + """Write to a file""" + try: + os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True) + with open(filepath, 'w', encoding='utf-8') as f: + f.write(content) + return f"Wrote {len(content)} bytes to {filepath}" + except Exception as e: + return f"Error writing file: {e}" + + def list_files(directory: str = ".") -> str: + """List files in directory""" + try: + files = [] + for name in sorted(os.listdir(directory)): + path = os.path.join(directory, name) + if os.path.isdir(path): + files.append(f"[DIR] {name}/") + else: + files.append(f"[FILE] {name}") + return "\n".join(files) + except Exception as e: + return f"Error listing files: {e}" + + # Python code execution + def execute_python(code: str) -> str: + """Execute Python code""" + try: + # Create temp file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(code) + temp_file = f.name + + try: + result = subprocess.run( + [sys.executable, temp_file], + capture_output=True, + text=True, + timeout=30 + ) + output = result.stdout + if result.stderr: + output += f"\n[stderr]: {result.stderr}" + return output or "Code executed successfully" + finally: + os.unlink(temp_file) + + except Exception as e: + return f"Error executing Python: {e}" + + # Send commands to specific processes + def send_to_process(app_name: str, command: str) -> str: + """Send command to a specific tracked process""" + try: + self.log_callback(f"Sending command to {app_name}: {command}") + + if app_name.lower() not in self.active_processes: + return f"Process {app_name} not tracked. Available: {list(self.active_processes.keys())}" + + pid = self.active_processes[app_name.lower()] + + # Check if process is still running + try: + proc = psutil.Process(pid) + if not proc.is_running(): + del self.active_processes[app_name.lower()] + return f"Process {app_name} (PID: {pid}) is no longer running" + except psutil.NoSuchProcess: + del self.active_processes[app_name.lower()] + return f"Process {app_name} (PID: {pid}) not found" + + # For Windows, use multiple methods for better compatibility + if platform.system() == "Windows": + import win32gui + import win32con + import win32api + import win32process + + # Find window by PID + def enum_windows_callback(hwnd, results): + try: + _, found_pid = win32process.GetWindowThreadProcessId(hwnd) + if found_pid == pid and win32gui.IsWindowVisible(hwnd): + results.append(hwnd) + except: + pass + return True + + windows = [] + win32gui.EnumWindows(enum_windows_callback, windows) + + if windows: + hwnd = windows[0] + window_title = win32gui.GetWindowText(hwnd) + self.log_callback(f"Found window: {window_title} (HWND: {hwnd})") + + # Bring window to foreground + try: + win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) + win32gui.SetForegroundWindow(hwnd) + time.sleep(0.5) + except Exception as e: + self.log_callback(f"Warning: Could not bring window to foreground: {e}") + + # Method 1: Try using SendKeys via win32api (more reliable for terminal apps) + try: + # Ensure the window has focus + win32gui.SetActiveWindow(hwnd) + time.sleep(0.2) + + # For PowerShell/CMD, try different approaches + if "powershell" in app_name.lower() or "cmd" in app_name.lower(): + # Method A: Use keybd_event for virtual key codes + import win32api + for char in command: + if char.isalnum() or char in " .-_/\\:": + vk_code = win32api.VkKeyScan(char) & 0xFF + win32api.keybd_event(vk_code, 0, 0, 0) # Key down + time.sleep(0.01) + win32api.keybd_event(vk_code, 0, win32con.KEYEVENTF_KEYUP, 0) # Key up + time.sleep(0.01) + + # Send Enter key + win32api.keybd_event(win32con.VK_RETURN, 0, 0, 0) + time.sleep(0.01) + win32api.keybd_event(win32con.VK_RETURN, 0, win32con.KEYEVENTF_KEYUP, 0) + + self.log_callback(f"Sent command using keyboard events to {app_name}") + return f"Command sent to {app_name} using keyboard simulation" + + else: + # Method B: Use WM_CHAR for other applications + for char in command: + win32gui.SendMessage(hwnd, win32con.WM_CHAR, ord(char), 0) + time.sleep(0.01) + + # Send Enter + win32gui.SendMessage(hwnd, win32con.WM_CHAR, 13, 0) + + self.log_callback(f"Sent command using WM_CHAR to {app_name}") + return f"Command sent to {app_name} using message posting" + + except Exception as e: + self.log_callback(f"Error in command sending methods: {e}") + + # Method 2: Fallback - try clipboard method for complex commands + try: + import win32clipboard + + # Copy command to clipboard + win32clipboard.OpenClipboard() + win32clipboard.EmptyClipboard() + win32clipboard.SetClipboardText(command) + win32clipboard.CloseClipboard() + + # Send Ctrl+V to paste + win32api.keybd_event(win32con.VK_CONTROL, 0, 0, 0) + win32api.keybd_event(ord('V'), 0, 0, 0) + time.sleep(0.05) + win32api.keybd_event(ord('V'), 0, win32con.KEYEVENTF_KEYUP, 0) + win32api.keybd_event(win32con.VK_CONTROL, 0, win32con.KEYEVENTF_KEYUP, 0) + + # Send Enter + time.sleep(0.1) + win32api.keybd_event(win32con.VK_RETURN, 0, 0, 0) + time.sleep(0.01) + win32api.keybd_event(win32con.VK_RETURN, 0, win32con.KEYEVENTF_KEYUP, 0) + + self.log_callback(f"Sent command using clipboard method to {app_name}") + return f"Command sent to {app_name} using clipboard paste" + + except Exception as e2: + self.log_callback(f"Clipboard method also failed: {e2}") + return f"Failed to send command to {app_name}: {e}" + + else: + # Try to get window processes for fallback selection + window_list = self.get_window_processes() + return f"No window found for {app_name} (PID: {pid}). Available windows:\n{window_list}" + + return f"Command sending not implemented for {platform.system()}" + + except Exception as e: + self.log_callback(f"Error sending command to {app_name}: {e}") + return f"Error sending command to {app_name}: {e}" + + def list_processes() -> str: + """List available processes to send commands to""" + try: + self.log_callback("Listing available processes") + + if not self.active_processes: + return "No tracked processes. Open an application first." + + result = "Available tracked processes:\n" + for app_name, pid in self.active_processes.items(): + try: + proc = psutil.Process(pid) + if proc.is_running(): + result += f"- {app_name} (PID: {pid}) - Running\n" + else: + result += f"- {app_name} (PID: {pid}) - Not running\n" + except psutil.NoSuchProcess: + result += f"- {app_name} (PID: {pid}) - Process not found\n" + + return result + + except Exception as e: + self.log_callback(f"Error listing processes: {e}") + return f"Error listing processes: {e}" + + def get_window_processes() -> str: + """Get list of visible window processes for fallback selection""" + try: + self.log_callback("Getting visible window processes") + + if platform.system() == "Windows": + import win32gui + import win32process + + def enum_windows_callback(hwnd, results): + if win32gui.IsWindowVisible(hwnd) and win32gui.GetWindowText(hwnd): + _, pid = win32process.GetWindowThreadProcessId(hwnd) + window_title = win32gui.GetWindowText(hwnd) + try: + proc = psutil.Process(pid) + results.append(f"PID: {pid} | {proc.name()} | {window_title}") + except: + results.append(f"PID: {pid} | Unknown | {window_title}") + return True + + windows = [] + win32gui.EnumWindows(enum_windows_callback, windows) + + if windows: + return "Visible window processes:\n" + "\n".join(windows[:20]) # Limit to 20 + else: + return "No visible windows found" + + return f"Window enumeration not implemented for {platform.system()}" + + except Exception as e: + self.log_callback(f"Error getting window processes: {e}") + return f"Error getting window processes: {e}" + + def send_to_pid(pid: int, command: str) -> str: + """Send command directly to a process by PID (fallback method)""" + try: + self.log_callback(f"Sending command to PID {pid}: {command}") + + # Check if process exists + try: + proc = psutil.Process(pid) + if not proc.is_running(): + return f"Process PID {pid} is not running" + app_name = proc.name() + except psutil.NoSuchProcess: + return f"Process PID {pid} not found" + + # Use the same logic as send_to_process but with PID directly + if platform.system() == "Windows": + import win32gui + import win32con + import win32api + import win32process + + # Find window by PID + def enum_windows_callback(hwnd, results): + try: + _, found_pid = win32process.GetWindowThreadProcessId(hwnd) + if found_pid == pid and win32gui.IsWindowVisible(hwnd): + results.append(hwnd) + except: + pass + return True + + windows = [] + win32gui.EnumWindows(enum_windows_callback, windows) + + if windows: + hwnd = windows[0] + window_title = win32gui.GetWindowText(hwnd) + self.log_callback(f"Found window: {window_title} (HWND: {hwnd})") + + # Bring window to foreground + try: + win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) + win32gui.SetForegroundWindow(hwnd) + time.sleep(0.5) + except Exception as e: + self.log_callback(f"Warning: Could not bring window to foreground: {e}") + + # Try keyboard simulation for terminal apps + try: + win32gui.SetActiveWindow(hwnd) + time.sleep(0.2) + + # Use clipboard method for reliability + try: + import win32clipboard + + # Copy command to clipboard + win32clipboard.OpenClipboard() + win32clipboard.EmptyClipboard() + win32clipboard.SetClipboardText(command) + win32clipboard.CloseClipboard() + + # Send Ctrl+V to paste + win32api.keybd_event(win32con.VK_CONTROL, 0, 0, 0) + win32api.keybd_event(ord('V'), 0, 0, 0) + time.sleep(0.05) + win32api.keybd_event(ord('V'), 0, win32con.KEYEVENTF_KEYUP, 0) + win32api.keybd_event(win32con.VK_CONTROL, 0, win32con.KEYEVENTF_KEYUP, 0) + + # Send Enter + time.sleep(0.1) + win32api.keybd_event(win32con.VK_RETURN, 0, 0, 0) + time.sleep(0.01) + win32api.keybd_event(win32con.VK_RETURN, 0, win32con.KEYEVENTF_KEYUP, 0) + + self.log_callback(f"Sent command to PID {pid} ({app_name}) using clipboard") + + # Track this process for future use + self.active_processes[app_name.lower()] = pid + + return f"Command sent to PID {pid} ({app_name})" + + except Exception as e: + self.log_callback(f"Clipboard method failed: {e}") + return f"Failed to send command to PID {pid}: {e}" + + except Exception as e: + self.log_callback(f"Error in window interaction: {e}") + return f"Error sending command to PID {pid}: {e}" + + else: + return f"No visible window found for PID {pid}" + + return f"Command sending not implemented for {platform.system()}" + + except Exception as e: + self.log_callback(f"Error sending command to PID {pid}: {e}") + return f"Error sending command to PID {pid}: {e}" + + # Register all tools + tools["powershell"] = powershell + tools["bash"] = bash + tools["shell"] = shell + tools["open_app"] = open_app + tools["read_file"] = read_file + tools["write_file"] = write_file + tools["list_files"] = list_files + tools["execute_python"] = execute_python + tools["send_to_process"] = send_to_process + tools["send_to_pid"] = send_to_pid + tools["list_processes"] = list_processes + tools["get_window_processes"] = get_window_processes + + return tools + + def parse_and_execute(self, ai_response: str) -> str: + """Parse AI response for commands and execute them""" + output = [] + + # Look for command patterns in the response + lines = ai_response.split('\n') + + for line in lines: + # Check for explicit command markers + if line.strip().startswith("EXECUTE:"): + command = line.replace("EXECUTE:", "").strip() + result = self._execute_command(command) + output.append(f"[Executed] {command}\n{result}") + + elif line.strip().startswith("OPEN:"): + app = line.replace("OPEN:", "").strip() + result = self.tools["open_app"](app) + output.append(f"[Opened] {app}\n{result}") + + elif line.strip().startswith("PYTHON:"): + code = line.replace("PYTHON:", "").strip() + result = self.tools["execute_python"](code) + output.append(f"[Python] {result}") + + if output: + return "\n".join(output) + else: + # If no explicit commands, return the AI response as-is + return ai_response + + def _execute_command(self, command: str) -> str: + """Execute a command using appropriate shell""" + if platform.system() == "Windows": + # Detect if it's a PowerShell command + ps_commands = ["Get-", "Set-", "New-", "Remove-", "Start-", "Stop-", "Test-"] + if any(command.startswith(cmd) for cmd in ps_commands): + return self.tools["powershell"](command) + else: + return self.tools["shell"](command) + else: + return self.tools["bash"](command) + + def process_request(self, user_request: str, model_response: str) -> str: + """Process user request with model response""" + # First, check if the model response contains tool calls + if "EXECUTE:" in model_response or "OPEN:" in model_response or "PYTHON:" in model_response: + return self.parse_and_execute(model_response) + + # Otherwise, try to infer intent from user request + request_lower = user_request.lower() + + # Direct application opening requests + if "open" in request_lower: + if "word" in request_lower: + return self.tools["open_app"]("word") + elif "notepad" in request_lower: + return self.tools["open_app"]("notepad") + elif "powershell" in request_lower: + return self.tools["open_app"]("powershell") + elif "terminal" in request_lower or "cmd" in request_lower: + return self.tools["open_app"]("cmd") + elif "calculator" in request_lower: + return self.tools["open_app"]("calculator") + + # File operations + elif "list files" in request_lower or "show files" in request_lower: + return self.tools["list_files"]() + + # If we can't determine intent, return the model response + return model_response + + +def create_simple_agent(log_callback=None): + """Create a simple agent executor""" + return SimpleAgentExecutor(log_callback=log_callback) \ No newline at end of file diff --git a/simple_convert.py b/simple_convert.py new file mode 100644 index 0000000..e39be69 --- /dev/null +++ b/simple_convert.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +""" +Simple PNG to ICO converter using Windows tools +""" + +import os +import sys +import subprocess +from pathlib import Path + +def png_to_ico_windows(png_path, ico_path): + """Convert PNG to ICO using Windows PowerShell and .NET Image classes.""" + + png_path = str(Path(png_path).resolve()) + ico_path = str(Path(ico_path).resolve()) + + # PowerShell script to convert PNG to ICO + ps_script = f''' +Add-Type -AssemblyName System.Drawing +$png = [System.Drawing.Image]::FromFile("{png_path}") +$ico = New-Object System.Drawing.Icon($png.GetHbitmap(), 256, 256) +$ico.Save("{ico_path}") +$ico.Dispose() +$png.Dispose() +''' + + try: + # Try PowerShell method first + result = subprocess.run([ + 'powershell', '-Command', ps_script + ], capture_output=True, text=True, shell=True) + + if result.returncode == 0 and os.path.exists(ico_path): + print(f"Successfully converted {png_path} to {ico_path}") + return True + else: + print(f"PowerShell conversion failed: {result.stderr}") + return False + + except Exception as e: + print(f"Error during conversion: {e}") + return False + +def main(): + png_file = "assets/Halico.png" + ico_file = "assets/Halico.ico" + + if not os.path.exists(png_file): + print(f"PNG file not found: {png_file}") + return 1 + + print(f"Converting {png_file} to {ico_file}...") + + if png_to_ico_windows(png_file, ico_file): + print("Conversion completed successfully!") + return 0 + else: + print("Conversion failed!") + return 1 + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/splash_screen.py b/splash_screen.py new file mode 100644 index 0000000..1e00ec4 --- /dev/null +++ b/splash_screen.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python3 +""" +DarkHal 2.0 Splash Screen + +A professional splash screen with logo, disclaimers, and loading animation. +""" + +import tkinter as tk +from tkinter import ttk +import os +import sys +import time +import threading +from pathlib import Path + +try: + from PIL import Image, ImageTk + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + + +class SplashScreen: + """Professional splash screen for DarkHal 2.0""" + + def __init__(self, callback=None): + """ + Initialize splash screen. + + Args: + callback (callable): Function to call when splash closes with selected hardware acceleration + """ + self.callback = callback + self.selected_acceleration = None + self.root = tk.Tk() + self.showing_logo = True + # References for dynamic logo sizing + self.logo_original = None # PIL.Image (original) + self.logo_photo = None # ImageTk.PhotoImage or tk.PhotoImage (current) + self.logo_label = None # tk.Label showing the logo + # References for dynamic logo sizing + self.logo_original = None # PIL.Image (original) + self.logo_photo = None # ImageTk.PhotoImage or tk.PhotoImage (current) + self.logo_label = None # tk.Label that displays the logo + # References for dynamic logo sizing + self.logo_original = None # PIL.Image (original) + self.logo_photo = None # ImageTk.PhotoImage or tk.PhotoImage (current) + self.logo_label = None # tk.Label showing the logo + self.setup_window() + self.create_logo_screen() + + # Switch to main screen after 2.5 seconds + self.root.after(2500, self.switch_to_main_screen) + + def setup_window(self): + """Configure the splash window""" + # Remove window decorations + self.root.overrideredirect(True) + + # Set window size - larger to accommodate all elements + width = 700 + height = 600 + + # Center on screen + screen_width = self.root.winfo_screenwidth() + screen_height = self.root.winfo_screenheight() + x = (screen_width - width) // 2 + y = (screen_height - height) // 2 + + self.root.geometry(f"{width}x{height}+{x}+{y}") + self.root.configure(bg='#1a1a1a') # Dark background + + # Make window topmost + self.root.attributes('-topmost', True) + + # Set window icon if available + try: + icon_path = Path("assets/Halico.ico") + if icon_path.exists(): + self.root.iconbitmap(str(icon_path)) + except: + pass + + def load_logo(self): + """Load and prepare the logo image""" + logo_path = Path("assets/logo.png") + + print(f"Looking for logo at: {logo_path.absolute()}") + print(f"Logo exists: {logo_path.exists()}") + print(f"PIL available: {PIL_AVAILABLE}") + + if not logo_path.exists(): + print("Logo file not found") + return None + + try: + if PIL_AVAILABLE: + # Use Pillow for better image handling + image = Image.open(logo_path) + print(f"Original image size: {image.size}") + # Resize to fit in the fixed 75x75 box + image = image.resize((65, 65), Image.Resampling.LANCZOS) + print(f"Resized image to: {image.size}") + return ImageTk.PhotoImage(image) + else: + # Fallback to tkinter's basic image support + return tk.PhotoImage(file=str(logo_path)) + except Exception as e: + print(f"Error loading logo: {e}") + return None + + def create_logo_screen(self): + """Create the initial logo-only screen""" + # Clear any existing widgets + for widget in self.root.winfo_children(): + widget.destroy() + + # Center container for logo + center_frame = tk.Frame(self.root, bg='#1a1a1a') + center_frame.pack(expand=True, fill=tk.BOTH) + + logo_path = Path("assets/logo.png") + # Dynamic path if PIL available, else static fallback + if PIL_AVAILABLE: + self.logo_original = self._load_logo_original() + if self.logo_original is not None: + # Create label and center it; image set during update + self.logo_label = tk.Label(center_frame, bg='#1a1a1a', borderwidth=0, highlightthickness=0) + self.logo_label.place(relx=0.5, rely=0.5, anchor=tk.CENTER) + # Bind window resize to keep logo responsive + self.root.bind("", self.update_logo_size) + # Initial sizing + self.update_logo_size() + return + elif logo_path.exists(): + # PIL present but image failed -> fallback to static display using tk.PhotoImage + try: + self.logo_photo = tk.PhotoImage(file=str(logo_path)) + self.logo_label = tk.Label(center_frame, image=self.logo_photo, bg='#1a1a1a') + self.logo_label.place(relx=0.5, rely=0.5, anchor=tk.CENTER) + return + except Exception: + pass + else: + # No PIL -> static display if possible + if logo_path.exists(): + try: + self.logo_photo = tk.PhotoImage(file=str(logo_path)) + self.logo_label = tk.Label(center_frame, image=self.logo_photo, bg='#1a1a1a') + self.logo_label.place(relx=0.5, rely=0.5, anchor=tk.CENTER) + return + except Exception: + pass + + # Fallback - large app name + logo_text = tk.Label(center_frame, text="DarkHal 2.0", font=("Arial", 48, "bold"), fg='#00ff88', bg='#1a1a1a') + logo_text.place(relx=0.5, rely=0.5, anchor=tk.CENTER) + + def load_logo_large(self): + """Load logo at original size for logo screen""" + logo_path = Path("assets/logo.png") + + if not logo_path.exists(): + return None + + try: + if PIL_AVAILABLE: + image = Image.open(logo_path) + # Keep original size or resize to reasonable splash size + max_size = min(300, min(image.size)) + image = image.resize((max_size, max_size), Image.Resampling.LANCZOS) + return ImageTk.PhotoImage(image) + else: + return tk.PhotoImage(file=str(logo_path)) + except Exception as e: + print(f"Error loading large logo: {e}") + return None + + def switch_to_main_screen(self): + """Switch from logo screen to main text screen""" + self.showing_logo = False + # Stop handling resize events once the logo screen ends + try: + self.root.unbind("") + except Exception: + pass + # Stop handling resize events once the logo screen ends + try: + self.root.unbind("") + except Exception: + pass + try: + self.root.unbind("") + except Exception: + pass + self.create_widgets() + + def create_widgets(self): + """Create the main text-only screen with buttons""" + # Clear any existing widgets + for widget in self.root.winfo_children(): + widget.destroy() + + # Main container + main_frame = tk.Frame(self.root, bg='#1a1a1a') + main_frame.pack(fill=tk.BOTH, expand=True, padx=30, pady=20) + + # Application name + title_label = tk.Label( + main_frame, + text="DarkHal 2.0", + font=("Arial", 24, "bold"), + fg='#00ff88', # Green accent color + bg='#1a1a1a' + ) + title_label.pack(pady=(0, 5)) + + # Subtitle + subtitle_label = tk.Label( + main_frame, + text="AI Model Management & Training Platform", + font=("Arial", 11), + fg='#cccccc', + bg='#1a1a1a' + ) + subtitle_label.pack(pady=(0, 15)) + + # Warning/Disclaimer section + warning_frame = tk.Frame(main_frame, bg='#2a2a2a', relief=tk.RAISED, bd=1) + warning_frame.pack(fill=tk.X, pady=(0, 15)) + + warning_title = tk.Label( + warning_frame, + text="⚠️ IMPORTANT WARNING", + font=("Arial", 10, "bold"), + fg='#ff4444', + bg='#2a2a2a' + ) + warning_title.pack(pady=(8, 5)) + + warning_text = """This software is provided "as is" without any warranties or guarantees. +The user assumes all responsibility for the use of this software and any +consequences that may arise from its use. The developers are not liable +for any damages, data loss, or other issues that may occur.""" + + warning_label = tk.Label( + warning_frame, + text=warning_text, + font=("Arial", 9), + fg='#cccccc', + bg='#2a2a2a', + wraplength=600, + justify=tk.CENTER + ) + warning_label.pack(pady=(0, 8)) + + # Terms agreement section + terms_frame = tk.Frame(main_frame, bg='#1a1a1a') + terms_frame.pack(pady=(10, 15)) + + terms_label = tk.Label( + terms_frame, + text="By continuing you agree to our terms", + font=("Arial", 10), + fg='#ffaa00', + bg='#1a1a1a' + ) + terms_label.pack() + + # Agreement section + agreement_frame = tk.Frame(main_frame, bg='#1a1a1a') + agreement_frame.pack(pady=(10, 20)) + + agreement_label = tk.Label( + agreement_frame, + text="I agree", + font=("Arial", 12, "bold"), + fg='#00ff88', + bg='#1a1a1a' + ) + agreement_label.pack(pady=(0, 15)) + + # Hardware acceleration buttons + buttons_frame = tk.Frame(agreement_frame, bg='#1a1a1a') + buttons_frame.pack() + + # CUDA button + cuda_btn = tk.Button( + buttons_frame, + text="Start with CUDA", + font=("Arial", 11, "bold"), + bg='#00aa55', + fg='white', + activebackground='#00ff88', + activeforeground='white', + relief=tk.RAISED, + bd=2, + padx=20, + pady=8, + command=lambda: self.start_application('cuda') + ) + cuda_btn.pack(side=tk.LEFT, padx=(0, 10)) + + # Intel button + intel_btn = tk.Button( + buttons_frame, + text="Start with Intel", + font=("Arial", 11, "bold"), + bg='#0078d4', + fg='white', + activebackground='#106ebe', + activeforeground='white', + relief=tk.RAISED, + bd=2, + padx=20, + pady=8, + command=lambda: self.start_application('intel') + ) + intel_btn.pack(side=tk.LEFT, padx=5) + + # CPU button + cpu_btn = tk.Button( + buttons_frame, + text="Start with CPU", + font=("Arial", 11, "bold"), + bg='#666666', + fg='white', + activebackground='#888888', + activeforeground='white', + relief=tk.RAISED, + bd=2, + padx=20, + pady=8, + command=lambda: self.start_application('cpu') + ) + cpu_btn.pack(side=tk.LEFT, padx=(10, 0)) + + # Copyright section at bottom + copyright_frame = tk.Frame(main_frame, bg='#1a1a1a') + copyright_frame.pack(side=tk.BOTTOM, fill=tk.X, pady=(20, 0)) + + copyright_label = tk.Label( + copyright_frame, + text="© 2025 Setec Labs", + font=("Arial", 9, "bold"), + fg='#888888', + bg='#1a1a1a' + ) + copyright_label.pack(side=tk.LEFT) + + author_label = tk.Label( + copyright_frame, + text="by ssSnake", + font=("Arial", 9, "italic"), + fg='#888888', + bg='#1a1a1a' + ) + author_label.pack(side=tk.RIGHT) + + def start_application(self, acceleration_type): + """Start the application with selected hardware acceleration""" + self.selected_acceleration = acceleration_type + print(f"Starting DarkHal 2.0 with {acceleration_type.upper()} acceleration...") + self.close_splash() + + def close_splash(self): + """Close the splash screen and call callback""" + self.root.destroy() + + if self.callback: + self.callback(self.selected_acceleration) + + def show(self): + """Show the splash screen""" + self.root.mainloop() + + # --- Dynamic logo helpers --- + def _load_logo_original(self): + """Load the original logo as a PIL Image (RGBA) without resizing.""" + logo_path = Path("assets/logo.png") + if not logo_path.exists(): + return None + try: + image = Image.open(logo_path) + return image.convert("RGBA") + except Exception as e: + print(f"Error loading original logo: {e}") + return None + + def update_logo_size(self, event=None): + """Resize the logo image dynamically to match the window size.""" + if not self.showing_logo: + return + if self.logo_label is None or self.logo_original is None: + return + try: + # Current window size + win_w = max(1, self.root.winfo_width()) + win_h = max(1, self.root.winfo_height()) + + # Target size: fraction of shortest side + target_box = int(min(win_w, win_h) * 0.45) + # Clamp to reasonable bounds + target_box = max(96, min(512, target_box)) + + ow, oh = self.logo_original.size + scale = min(target_box / ow, target_box / oh) + new_w = max(1, int(ow * scale)) + new_h = max(1, int(oh * scale)) + + resized = self.logo_original.resize((new_w, new_h), Image.Resampling.LANCZOS) + self.logo_photo = ImageTk.PhotoImage(resized) + self.logo_label.configure(image=self.logo_photo) + except Exception as e: + print(f"Error dynamically resizing logo: {e}") + """Resize the logo to fit dynamically within the current window size.""" + if not self.showing_logo: + return # Only active on the logo screen + if self.logo_label is None or self.logo_original is None: + return + try: + # Current window size + win_w = max(1, self.root.winfo_width()) + win_h = max(1, self.root.winfo_height()) + + # Target box as a fraction of the shortest side + target_box = int(min(win_w, win_h) * 0.45) + # Clamp to reasonable bounds + target_box = max(96, min(512, target_box)) + + ow, oh = self.logo_original.size + scale = min(target_box / ow, target_box / oh) + new_w = max(1, int(ow * scale)) + new_h = max(1, int(oh * scale)) + + # Resize with high-quality resampling + resized = self.logo_original.resize((new_w, new_h), Image.Resampling.LANCZOS) + self.logo_photo = ImageTk.PhotoImage(resized) + self.logo_label.configure(image=self.logo_photo) + except Exception as e: + print(f"Error dynamically resizing logo: {e}") + +class SplashManager: + """Manager for showing splash screen and launching main application""" + + def __init__(self, main_app_callback=None): + """ + Initialize splash manager. + + Args: + main_app_callback (callable): Function to call when splash closes with acceleration type + """ + self.main_app_callback = main_app_callback + + def show_splash_and_launch(self): + """Show splash screen then launch main application""" + splash = SplashScreen(callback=self.launch_main_app) + splash.show() + + def launch_main_app(self, acceleration_type): + """Launch the main application after splash""" + if self.main_app_callback: + self.main_app_callback(acceleration_type) + else: + print(f"DarkHal 2.0 ready with {acceleration_type.upper() if acceleration_type else 'default'} acceleration!") + + +def demo_main_app(acceleration_type): + """Demo main application for testing""" + root = tk.Tk() + root.title(f"DarkHal 2.0 - Main Application ({acceleration_type.upper()})") + root.geometry("800x600") + + label = tk.Label(root, text=f"Welcome to DarkHal 2.0!\nRunning with {acceleration_type.upper()} acceleration", + font=("Arial", 16), justify=tk.CENTER) + label.pack(expand=True) + + root.mainloop() + + +def main(): + """Main entry point for testing splash screen""" + print("Starting DarkHal 2.0 with splash screen...") + + # Create splash manager + splash_manager = SplashManager(main_app_callback=demo_main_app) + + # Show splash and launch app + splash_manager.show_splash_and_launch() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_chess_integration.py b/test_chess_integration.py new file mode 100644 index 0000000..4290921 --- /dev/null +++ b/test_chess_integration.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +""" +Integration test for the chess game with LLM. +This simulates the chess logic without the GUI. +""" + +import sys +import os +sys.path.append('.') + +try: + import chess + print("✓ python-chess imported successfully") +except ImportError: + print("✗ python-chess not available") + sys.exit(1) + +class MockSettings: + """Mock settings manager for testing.""" + def __init__(self): + self.settings = { + 'paths.last_model_path': '/path/to/model.gguf', # Mock path + 'model_settings.default_n_ctx': 4096, + 'model_settings.default_n_gpu_layers': 0 + } + + def get(self, key, default=None): + return self.settings.get(key, default) + +class MockLlama: + """Mock LLM for testing chess integration.""" + def __init__(self, *args, **kwargs): + self.moves_to_suggest = ['e2e4', 'g1f3', 'd2d4', 'b1c3'] # Common opening moves + self.call_count = 0 + + def __call__(self, prompt, **kwargs): + """Simulate LLM response with chess moves.""" + # Cycle through suggested moves + move = self.moves_to_suggest[self.call_count % len(self.moves_to_suggest)] + self.call_count += 1 + + print(f"Mock LLM responding with: {move}") + return { + 'choices': [{'text': move}] + } + +class ChessGameSimulator: + """Simplified chess game simulator based on our implementation.""" + + def __init__(self): + self.board = chess.Board() + self.settings = MockSettings() + self.llm_cache = None + self.move_history = [] + + def _get_llm_instance(self): + """Get mock LLM instance.""" + if not self.llm_cache: + # Use mock instead of real LLM + self.llm_cache = MockLlama() + return self.llm_cache + + def _parse_move_from_response(self, text): + """Parse chess move from LLM response using multiple methods.""" + text = text.strip().lower() + legal_moves = list(self.board.legal_moves) + legal_uci = [move.uci() for move in legal_moves] + + # Method 1: Direct UCI format match + for move_uci in legal_uci: + if move_uci in text: + return chess.Move.from_uci(move_uci) + + # Method 2: Look for 4-5 character sequences that could be UCI + import re + uci_pattern = r'\b[a-h][1-8][a-h][1-8][qrbn]?\b' + matches = re.findall(uci_pattern, text) + for match in matches: + if match in legal_uci: + return chess.Move.from_uci(match) + + # Method 3: Try to find SAN notation and convert + for move in legal_moves: + san = self.board.san(move).lower() + san_clean = san.replace('+', '').replace('#', '').replace('x', '') + if san_clean in text or san in text: + return move + + return None + + def _create_chess_prompt(self): + """Create a chess-specific prompt for the LLM.""" + fen = self.board.fen() + legal_moves = [move.uci() for move in self.board.legal_moves] + move_history = [self.board.san(move) for move in self.board.move_stack[-6:]] + + ai_color = "white" if self.board.turn == chess.WHITE else "black" + current_turn = "white" if self.board.turn == chess.WHITE else "black" + board_unicode = self.board.unicode() + + prompt = f"""You are a professional chess player and you play as {ai_color}. Now is your turn to make a move. + +Current board position: +{board_unicode} + +Position (FEN): {fen} +Turn: {current_turn} +Recent moves: {' '.join(move_history) if move_history else 'Game start'} + +Legal moves available (UCI format): {', '.join(legal_moves)} + +As an expert chess player, choose the BEST move considering: +- King safety and piece protection +- Center control and piece development +- Tactical opportunities (captures, forks, pins, skewers) +- Positional advantages +- Endgame principles if material is low + +Reply with ONLY the move in UCI format (examples: e2e4, g1f3, e7e8q):""" + + return prompt + + def _query_llm_for_move(self): + """Query the LLM for a chess move.""" + try: + llm = self._get_llm_instance() + if not llm: + return None + + prompt = self._create_chess_prompt() + + # Use multiple attempts with different temperatures + max_attempts = 3 + temperatures = [0.1, 0.3, 0.5] + + for attempt in range(max_attempts): + try: + response = llm( + prompt, + max_tokens=20, + temperature=temperatures[attempt], + stop=["\n", " ", ".", ",", "because", "since", "as", "the"], + echo=False + ) + + text = response['choices'][0]['text'].strip().lower() + move = self._parse_move_from_response(text) + + if move and move in self.board.legal_moves: + print(f"✓ LLM suggested valid move: {move.uci()} (attempt {attempt + 1})") + return move + elif move: + print(f"✗ LLM suggested illegal move: {move.uci()}") + else: + print(f"✗ Could not parse move from response: '{text}'") + + except Exception as e: + print(f"✗ LLM attempt {attempt + 1} failed: {e}") + continue + + return None + + except Exception as e: + print(f"✗ LLM query failed: {e}") + return None + + def _get_strategic_move(self, legal_moves): + """Get strategic move using chess heuristics.""" + import random + scored_moves = [] + + for move in legal_moves: + score = 0 + + # Prefer central squares + to_square = move.to_square + file = chess.square_file(to_square) + rank = chess.square_rank(to_square) + center_distance = abs(3.5 - file) + abs(3.5 - rank) + score += (7 - center_distance) * 2 + + # Prefer piece development + piece = self.board.piece_at(move.from_square) + if piece and piece.piece_type in [chess.KNIGHT, chess.BISHOP]: + if chess.square_rank(move.from_square) in [0, 7]: + score += 15 + + scored_moves.append((move, score)) + + scored_moves.sort(key=lambda x: x[1] + random.random() * 2, reverse=True) + return scored_moves[0][0] + + def get_ai_move(self): + """Get AI move with LLM integration and fallback.""" + legal_moves = list(self.board.legal_moves) + if not legal_moves: + return None + + # Try LLM first + llm_move = self._query_llm_for_move() + if llm_move and llm_move in legal_moves: + return llm_move + + # Fallback to strategic heuristics + print("Using strategic fallback move") + return self._get_strategic_move(legal_moves) + + def make_move(self, move): + """Make a move on the board.""" + if move in self.board.legal_moves: + san = self.board.san(move) + self.board.push(move) + self.move_history.append(san) + return True + return False + + def play_moves(self, num_moves=4): + """Play a few moves to test the integration.""" + print(f"\n=== Playing {num_moves} moves ===") + + for i in range(num_moves): + if self.board.is_game_over(): + print("Game over!") + break + + current_turn = "White" if self.board.turn == chess.WHITE else "Black" + print(f"\nMove {i+1} - {current_turn} to play") + print(f"Current position: {self.board.fen()}") + + ai_move = self.get_ai_move() + if ai_move: + success = self.make_move(ai_move) + if success: + print(f"✓ Played: {self.move_history[-1]} ({ai_move.uci()})") + else: + print(f"✗ Failed to make move: {ai_move.uci()}") + break + else: + print("✗ No move found") + break + + print(f"\nFinal position after {len(self.move_history)} moves:") + print(self.board.unicode()) + print(f"Move history: {' '.join(self.move_history)}") + +def main(): + """Main test function.""" + print("Testing Chess Integration with Mock LLM") + print("=" * 50) + + try: + game = ChessGameSimulator() + game.play_moves(6) # Play 6 moves to test the integration + + print("\n" + "=" * 50) + print("✓ Chess integration test completed successfully!") + print("\nKey features tested:") + print("- LLM querying with multiple attempts") + print("- Move parsing from LLM responses") + print("- Strategic fallback when LLM fails") + print("- Proper move validation") + print("- Chess position tracking") + + except Exception as e: + print(f"✗ Integration test failed: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_chess_logic.py b/test_chess_logic.py new file mode 100644 index 0000000..d745ef1 --- /dev/null +++ b/test_chess_logic.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +""" +Test script for the improved chess logic functionality. +""" + +import sys +sys.path.append('.') + +try: + import chess + print("✓ python-chess library imported successfully") +except ImportError: + print("✗ python-chess library not available") + sys.exit(1) + +# Test basic chess functionality +def test_chess_functionality(): + """Test basic chess board and move functionality.""" + print("\n=== Testing Chess Functionality ===") + + # Create board + board = chess.Board() + print(f"✓ Chess board created: {board.fen()}") + + # Test Unicode display + print("✓ Board Unicode representation:") + print(board.unicode()) + + # Test legal moves + legal_moves = list(board.legal_moves) + print(f"✓ Legal moves available: {len(legal_moves)}") + print(f" First 5 moves: {[move.uci() for move in legal_moves[:5]]}") + + # Test making a move + first_move = legal_moves[0] + board.push(first_move) + print(f"✓ Made move: {first_move.uci()}") + print(f" New position: {board.fen()}") + + # Test SAN notation + board.pop() # Undo the move + san_notation = board.san(first_move) + print(f"✓ SAN notation for {first_move.uci()}: {san_notation}") + + return True + +def test_move_parsing(): + """Test the move parsing logic from our improved chess implementation.""" + print("\n=== Testing Move Parsing Logic ===") + + # Simulate the _parse_move_from_response method + board = chess.Board() + legal_moves = list(board.legal_moves) + legal_uci = [move.uci() for move in legal_moves] + + test_responses = [ + "e2e4", # Direct UCI + "I think e2e4 is good", # UCI in sentence + "Move Nf3", # SAN notation + "The best move is g1f3", # UCI in sentence + "1. e4", # SAN with number + "e4", # Short SAN + ] + + def parse_move_from_response(text): + """Simplified version of our parsing method.""" + text = text.strip().lower() + + # Method 1: Direct UCI format match + for move_uci in legal_uci: + if move_uci in text: + return chess.Move.from_uci(move_uci) + + # Method 2: Try SAN notation + for move in legal_moves: + san = board.san(move).lower() + san_clean = san.replace('+', '').replace('#', '').replace('x', '') + if san_clean in text or san in text: + return move + + return None + + for response in test_responses: + parsed_move = parse_move_from_response(response) + if parsed_move: + print(f"✓ Parsed '{response}' -> {parsed_move.uci()} ({board.san(parsed_move)})") + else: + print(f"✗ Failed to parse '{response}'") + + return True + +def test_chess_prompt_creation(): + """Test chess prompt creation logic.""" + print("\n=== Testing Chess Prompt Creation ===") + + board = chess.Board() + legal_moves = [move.uci() for move in board.legal_moves] + move_history = [] # Empty for starting position + + # Simulate prompt creation + ai_color = "black" + current_turn = "white" + board_unicode = board.unicode() + fen = board.fen() + + prompt = f"""You are a professional chess player and you play as {ai_color}. Now is your turn to make a move. + +Current board position: +{board_unicode} + +Position (FEN): {fen} +Turn: {current_turn} +Recent moves: {' '.join(move_history) if move_history else 'Game start'} + +Legal moves available (UCI format): {', '.join(legal_moves)} + +As an expert chess player, choose the BEST move considering: +- King safety and piece protection +- Center control and piece development +- Tactical opportunities (captures, forks, pins, skewers) +- Positional advantages +- Endgame principles if material is low + +Reply with ONLY the move in UCI format (examples: e2e4, g1f3, e7e8q):""" + + print(f"✓ Chess prompt created successfully ({len(prompt)} characters)") + print("✓ Prompt includes:") + print(" - Board Unicode representation") + print(" - FEN position") + print(" - Legal moves list") + print(" - Strategic guidance") + print(" - Clear output format instruction") + + return True + +def test_strategic_move_evaluation(): + """Test basic strategic move evaluation.""" + print("\n=== Testing Strategic Move Evaluation ===") + + board = chess.Board() + legal_moves = list(board.legal_moves) + + # Simplified scoring like in our implementation + scored_moves = [] + + for move in legal_moves: + score = 0 + + # Prefer central squares + to_square = move.to_square + file = chess.square_file(to_square) + rank = chess.square_rank(to_square) + center_distance = abs(3.5 - file) + abs(3.5 - rank) + score += (7 - center_distance) * 2 + + # Prefer piece development + piece = board.piece_at(move.from_square) + if piece and piece.piece_type in [chess.KNIGHT, chess.BISHOP]: + if chess.square_rank(move.from_square) in [0, 7]: # From back rank + score += 15 + + scored_moves.append((move, score)) + + # Sort by score + scored_moves.sort(key=lambda x: x[1], reverse=True) + + print(f"✓ Evaluated {len(scored_moves)} moves") + print(" Top 3 moves by strategic score:") + for i, (move, score) in enumerate(scored_moves[:3]): + san = board.san(move) + print(f" {i+1}. {san} ({move.uci()}) - Score: {score}") + + return True + +if __name__ == "__main__": + print("Testing Improved Chess Implementation") + print("=" * 50) + + try: + success = True + success &= test_chess_functionality() + success &= test_move_parsing() + success &= test_chess_prompt_creation() + success &= test_strategic_move_evaluation() + + print("\n" + "=" * 50) + if success: + print("✓ All tests passed! Chess integration is working correctly.") + print("\nKey improvements implemented:") + print("- Multi-attempt LLM querying with temperature variation") + print("- Robust move parsing from LLM responses") + print("- Structured chess prompts like llm_chess project") + print("- Better error handling and fallback mechanisms") + print("- Strategic move evaluation as backup") + print("- Proper move validation and logging") + else: + print("✗ Some tests failed.") + + except Exception as e: + print(f"✗ Test suite failed with error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/test_kv_cache.py b/test_kv_cache.py new file mode 100644 index 0000000..2ccd39a --- /dev/null +++ b/test_kv_cache.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Test script to validate KV caching functionality across different model formats. +This script tests the complete KV caching system implementation. +""" + +import os +import sys +import time +import torch +import psutil +from typing import List, Dict, Any + +# Add the project root to the Python path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from llm_runtime import load_model, GenerateConfig +from llm_runtime.chat_session import get_session_manager + + +def monitor_gpu_memory(): + """Monitor GPU memory usage if CUDA is available""" + if torch.cuda.is_available(): + return torch.cuda.memory_allocated() / 1024**3 # GB + return 0.0 + + +def monitor_cpu_memory(): + """Monitor CPU memory usage""" + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024**3 # GB + + +def test_transformers_kv_cache(): + """Test KV caching with a Transformers model""" + print("=" * 60) + print("Testing Transformers KV Cache") + print("=" * 60) + + # Test with a small model for demonstration + model_path = "microsoft/DialoGPT-small" # Small model for testing + + try: + print("📥 Loading model...") + + start_time = time.time() + model = load_model(model_path, device_map="auto") + load_time = time.time() - start_time + + print(f"✅ Model loaded in {load_time:.2f} seconds") + print(f"📊 Initial GPU memory: {monitor_gpu_memory():.2f} GB") + print(f"📊 Initial CPU memory: {monitor_cpu_memory():.2f} GB") + + # Test session management + session_manager = get_session_manager() + session_id = "test_session" + + # Get model info if available + if hasattr(model, 'get_model_info'): + model_info = model.get_model_info() + print(f"📋 Model info: {model_info}") + + # Test 1: First generation (prefill phase) + print("\n🔄 Test 1: First generation (should trigger prefill)") + prompt1 = "Hello, how are you?" + cfg = GenerateConfig(max_tokens=50, temperature=0.7) + + start_time = time.time() + response1 = model.generate(prompt1, cfg=cfg, session_id=session_id) + gen1_time = time.time() - start_time + + print(f"🤖 Response: {response1}") + print(f"⏱️ Generation time: {gen1_time:.2f} seconds") + print(f"📊 GPU memory after gen1: {monitor_gpu_memory():.2f} GB") + + # Check session state + if hasattr(model, 'get_session_info'): + session_info = model.get_session_info(session_id) + print(f"📋 Session info after gen1: {session_info}") + + # Test 2: Second generation (should use cache) + print("\n🔄 Test 2: Second generation (should use KV cache)") + prompt2 = prompt1 + " " + response1 + " Tell me more." + + start_time = time.time() + response2 = model.generate(prompt2, cfg=cfg, session_id=session_id) + gen2_time = time.time() - start_time + + print(f"🤖 Response: {response2}") + print(f"⏱️ Generation time: {gen2_time:.2f} seconds") + print(f"📊 GPU memory after gen2: {monitor_gpu_memory():.2f} GB") + + # Check session state again + if hasattr(model, 'get_session_info'): + session_info = model.get_session_info(session_id) + print(f"📋 Session info after gen2: {session_info}") + + # Test 3: Streaming generation with cache + print("\n🔄 Test 3: Streaming generation with KV cache") + prompt3 = prompt2 + " " + response2 + " What's your favorite color?" + + print("🤖 Streaming response: ", end="", flush=True) + start_time = time.time() + streamed_tokens = [] + + for token in model.stream(prompt3, cfg=cfg, session_id=session_id): + print(token, end="", flush=True) + streamed_tokens.append(token) + + stream_time = time.time() - start_time + response3 = "".join(streamed_tokens) + + print(f"\n⏱️ Streaming time: {stream_time:.2f} seconds") + print(f"📊 Final GPU memory: {monitor_gpu_memory():.2f} GB") + + # Performance analysis + print(f"\n📈 Performance Analysis:") + print(f" First generation (prefill): {gen1_time:.2f}s") + print(f" Second generation (cached): {gen2_time:.2f}s") + print(f" Third generation (streamed): {stream_time:.2f}s") + + if gen2_time < gen1_time * 0.8: # Expect at least 20% improvement + print("✅ KV caching appears to be working (faster subsequent generations)") + else: + print("⚠️ KV caching may not be providing expected speedup") + + # Test 4: Cache invalidation + print("\n🔄 Test 4: Cache invalidation test") + if hasattr(model, 'clear_session_cache'): + model.clear_session_cache(session_id) + print("✅ Cache cleared") + + # New generation after cache clear + start_time = time.time() + response4 = model.generate("This is a completely new conversation.", cfg=cfg, session_id=session_id) + gen4_time = time.time() - start_time + + print(f"🤖 Response after cache clear: {response4}") + print(f"⏱️ Generation time after cache clear: {gen4_time:.2f} seconds") + + print("✅ Transformers KV cache test completed successfully") + return True + + except Exception as e: + print(f"❌ Transformers KV cache test failed: {e}") + import traceback + traceback.print_exc() + return False + + +def test_gguf_compatibility(): + """Test that GGUF models still work with their built-in caching""" + print("\n" + "=" * 60) + print("Testing GGUF Model Compatibility") + print("=" * 60) + + # Look for a GGUF model in the models directory + models_dir = "/mnt/c/Users/mdavi/PycharmProjects/LLM_Train/models" + gguf_model = None + + if os.path.exists(models_dir): + for file in os.listdir(models_dir): + if file.endswith('.gguf'): + gguf_model = os.path.join(models_dir, file) + break + + if not gguf_model: + print("⚠️ No GGUF model found for testing - skipping GGUF test") + return True + + try: + print("📥 Loading GGUF model...") + + start_time = time.time() + model = load_model(gguf_model, n_ctx=4096, n_gpu_layers=0) + load_time = time.time() - start_time + + print(f"✅ GGUF Model loaded in {load_time:.2f} seconds") + + # Test generation + prompt = "Hello, what is artificial intelligence?" + cfg = GenerateConfig(max_tokens=50, temperature=0.7) + + start_time = time.time() + response = model.generate(prompt, cfg=cfg) + gen_time = time.time() - start_time + + print(f"🤖 Response: {response}") + print(f"⏱️ Generation time: {gen_time:.2f} seconds") + print("✅ GGUF compatibility test completed successfully") + + return True + + except Exception as e: + print(f"❌ GGUF compatibility test failed: {e}") + import traceback + traceback.print_exc() + return False + + +def main(): + """Run all KV cache tests""" + print("🚀 Starting KV Cache Validation Tests") + print(f"🔧 PyTorch version: {torch.__version__}") + print(f"🔧 CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"🔧 CUDA device: {torch.cuda.get_device_name()}") + + print(f"💾 Initial CPU memory: {monitor_cpu_memory():.2f} GB") + print(f"🖥️ Initial GPU memory: {monitor_gpu_memory():.2f} GB") + + results = [] + + # Test 1: Transformers KV caching + results.append(test_transformers_kv_cache()) + + # Test 2: GGUF compatibility + results.append(test_gguf_compatibility()) + + # Summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + + passed = sum(results) + total = len(results) + + print(f"✅ Tests passed: {passed}/{total}") + + if passed == total: + print("🎉 All KV caching tests passed!") + print("\n📋 KV Caching System Features Validated:") + print(" ✅ Multi-format support (Transformers + GGUF)") + print(" ✅ Persistent KV cache across turns") + print(" ✅ Prefill and decode phases") + print(" ✅ Session management") + print(" ✅ Cache invalidation") + print(" ✅ Streaming support with cache") + print(" ✅ GPU utilization during inference") + else: + print("❌ Some tests failed - check implementation") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_windows_chess.py b/test_windows_chess.py new file mode 100644 index 0000000..e3ac067 --- /dev/null +++ b/test_windows_chess.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Windows-specific test for chess functionality. +Tests the core chess logic without GUI dependencies. +""" + +import sys +import os + +# Add the project root to the path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +def test_chess_imports(): + """Test that all required chess dependencies are available.""" + print("Testing Chess Dependencies for Windows") + print("=" * 50) + + # Test python-chess + try: + import chess + print("✓ python-chess library imported successfully") + + # Test basic functionality + board = chess.Board() + print(f"✓ Chess board created: {board.fen()}") + + legal_moves = list(board.legal_moves) + print(f"✓ Found {len(legal_moves)} legal opening moves") + + # Test Unicode display (important for Windows console) + unicode_board = board.unicode() + print("✓ Unicode board representation works") + + return True + + except ImportError as e: + print(f"✗ python-chess not available: {e}") + print(" Install with: pip install python-chess") + return False + except Exception as e: + print(f"✗ Chess library error: {e}") + return False + +def test_llama_cpp_availability(): + """Test if llama-cpp-python is available (optional for testing).""" + try: + import llama_cpp + print("✓ llama-cpp-python is available") + print(f" Version info: {llama_cpp.__version__ if hasattr(llama_cpp, '__version__') else 'Unknown'}") + return True + except ImportError: + print("⚠ llama-cpp-python not available (optional for testing)") + print(" For full LLM functionality, install with: pip install llama-cpp-python") + return False + except Exception as e: + print(f"⚠ llama-cpp-python error: {e}") + return False + +def test_darkhal_imports(): + """Test that our DarkHal chess modules import correctly.""" + try: + # Test settings manager + from darkhal.settings_manager import SettingsManager + settings = SettingsManager() + print("✓ DarkHal SettingsManager imported successfully") + + # Test if we can create a mock chess environment + print("✓ DarkHal chess infrastructure is ready") + return True + + except ImportError as e: + print(f"✗ DarkHal module import error: {e}") + return False + except Exception as e: + print(f"✗ DarkHal initialization error: {e}") + return False + +def test_chess_ai_logic(): + """Test the core chess AI logic we implemented.""" + print("\nTesting Chess AI Logic") + print("-" * 30) + + try: + import chess + + # Test our move parsing logic + def parse_move_from_response(text, board): + """Simplified version of our parsing method.""" + text = text.strip().lower() + legal_moves = list(board.legal_moves) + legal_uci = [move.uci() for move in legal_moves] + + # Direct UCI format match + for move_uci in legal_uci: + if move_uci in text: + return chess.Move.from_uci(move_uci) + + # Try SAN notation + for move in legal_moves: + san = board.san(move).lower() + san_clean = san.replace('+', '').replace('#', '').replace('x', '') + if san_clean in text or san in text: + return move + + return None + + # Test cases + board = chess.Board() + test_cases = [ + "e2e4", + "I suggest e2e4", + "Nf3 is good", + "play d2d4", + ] + + print("Testing move parsing:") + for test in test_cases: + move = parse_move_from_response(test, board) + if move: + print(f" ✓ '{test}' → {move.uci()} ({board.san(move)})") + else: + print(f" ✗ '{test}' → No valid move found") + + # Test strategic move evaluation + def get_strategic_move(legal_moves, board): + """Test strategic move selection.""" + scored_moves = [] + + for move in legal_moves: + score = 0 + + # Prefer central squares + to_square = move.to_square + file = chess.square_file(to_square) + rank = chess.square_rank(to_square) + center_distance = abs(3.5 - file) + abs(3.5 - rank) + score += (7 - center_distance) * 2 + + # Prefer piece development + piece = board.piece_at(move.from_square) + if piece and piece.piece_type in [chess.KNIGHT, chess.BISHOP]: + if chess.square_rank(move.from_square) in [0, 7]: + score += 15 + + scored_moves.append((move, score)) + + scored_moves.sort(key=lambda x: x[1], reverse=True) + return scored_moves[0][0] if scored_moves else None + + legal_moves = list(board.legal_moves) + best_move = get_strategic_move(legal_moves, board) + if best_move: + print(f"✓ Strategic move selection: {best_move.uci()} ({board.san(best_move)})") + else: + print("✗ Strategic move selection failed") + + return True + + except Exception as e: + print(f"✗ Chess AI logic test failed: {e}") + return False + +def test_windows_console(): + """Test Windows console compatibility.""" + print("\nTesting Windows Console Compatibility") + print("-" * 40) + + try: + import chess + board = chess.Board() + + # Test Unicode chess pieces display + print("Testing Unicode chess pieces:") + pieces = { + chess.PAWN: "♟♙", chess.ROOK: "♜♖", chess.KNIGHT: "♞♘", + chess.BISHOP: "♝♗", chess.QUEEN: "♛♕", chess.KING: "♚♔" + } + + for piece_type, symbols in pieces.items(): + print(f" {chess.piece_name(piece_type).title()}: {symbols}") + + print("\nSample board position:") + print(board.unicode()) + + print("✓ Windows console Unicode display working") + return True + + except Exception as e: + print(f"✗ Windows console test failed: {e}") + return False + +def main(): + """Run all tests.""" + print("DarkHal Chess - Windows Compatibility Test") + print("=" * 60) + + results = [] + + # Run tests + results.append(("Chess Dependencies", test_chess_imports())) + results.append(("LLM Dependencies", test_llama_cpp_availability())) + results.append(("DarkHal Modules", test_darkhal_imports())) + results.append(("Chess AI Logic", test_chess_ai_logic())) + results.append(("Windows Console", test_windows_console())) + + # Summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + + passed = 0 + total = len(results) + + for test_name, result in results: + status = "PASS" if result else "FAIL" + icon = "✓" if result else "✗" + print(f"{icon} {test_name:<25} {status}") + if result: + passed += 1 + + print("-" * 60) + print(f"Total: {passed}/{total} tests passed") + + if passed == total: + print("\n🎉 All tests passed! Your chess implementation is ready for Windows.") + print("\nNext steps:") + print("1. Load a chess model (any LLM works)") + print("2. Launch DarkHal and try the chess feature") + print("3. Play against the AI!") + else: + print(f"\n⚠ {total - passed} test(s) failed. Please install missing dependencies:") + if not any(name == "Chess Dependencies" and result for name, result in results): + print(" pip install python-chess") + if not any(name == "LLM Dependencies" and result for name, result in results): + print(" pip install llama-cpp-python") + + return passed == total + +if __name__ == "__main__": + success = main() + + # Keep window open on Windows + if os.name == 'nt': # Windows + input("\nPress Enter to exit...") + + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tools/debug_torch.py b/tools/debug_torch.py new file mode 100644 index 0000000..d7e2574 --- /dev/null +++ b/tools/debug_torch.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +""" +Quick sanity check for PyTorch and CUDA setup +""" +import torch + +def check_torch_setup(): + """Check PyTorch and CUDA configuration""" + print("=== PyTorch & CUDA Debug Info ===") + print("torch:", torch.__version__) + print("cuda available:", torch.cuda.is_available()) + print("cuda runtime:", torch.version.cuda) + + if torch.cuda.is_available(): + print("gpu count:", torch.cuda.device_count(), "name:", torch.cuda.get_device_name(0)) + print("cuda arch list:", torch.cuda.get_arch_list()) + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + print(f"GPU {i}: {props.name}, Memory: {props.total_memory/1024**3:.1f}GB") + else: + print("WARNING: CUDA not available - you may have CPU-only torch") + print("To fix: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121") + + # Test basic tensor operations + try: + x = torch.randn(2, 2) + if torch.cuda.is_available(): + x_gpu = x.cuda() + print("✓ Basic CUDA tensor operations work") + print("✓ Basic CPU tensor operations work") + except Exception as e: + print(f"✗ Tensor operations failed: {e}") + +if __name__ == "__main__": + check_torch_setup() \ No newline at end of file diff --git a/tools/inspect_devices.py b/tools/inspect_devices.py new file mode 100644 index 0000000..b4e90e4 --- /dev/null +++ b/tools/inspect_devices.py @@ -0,0 +1,160 @@ +import torch +import re +from transformers import AutoModelForCausalLM, AutoTokenizer + +def dtype_nbytes(dt: torch.dtype) -> int: + return { + torch.float32: 4, torch.float: 4, + torch.float16: 2, torch.bfloat16: 2, + torch.int8: 1, torch.uint8: 1, + torch.int4: 0.5, # pseudo for 4-bit quant libs + }.get(dt, 4) + +def pretty_bytes(n: float) -> str: + for u in ["B","KB","MB","GB","TB"]: + if n < 1024 or u == "TB": return f"{n:.2f} {u}" + n /= 1024 + +def inspect_model_devices(model_path_or_id: str) -> str: + """Inspect where model parameters are placed and return detailed report""" + output = [] + + try: + output.append(f"=== Inspecting Model: {model_path_or_id} ===\n") + + # Load model as-is (don't force a map yet—show reality) + model = AutoModelForCausalLM.from_pretrained( + model_path_or_id, + torch_dtype="auto", + device_map="auto", + low_cpu_mem_usage=True + ) + + output.append(f">>> hf_device_map present: {hasattr(model, 'hf_device_map')}") + if hasattr(model, "hf_device_map"): + output.append(">>> device_map (first 20 entries):") + for i, (k, v) in enumerate(model.hf_device_map.items()): + if i < 20: + output.append(f" {k:40s} -> {v}") + if len(model.hf_device_map) > 20: + output.append(f" ... and {len(model.hf_device_map) - 20} more entries") + + totals = {} + by_dtype = {} + on_meta = [] + + for n, p in model.named_parameters(): + dev = str(p.device) + totals[dev] = totals.get(dev, 0) + p.numel() * p.element_size() + by_dtype[p.dtype] = by_dtype.get(p.dtype, 0) + p.numel() * p.element_size() + if dev == "meta": + on_meta.append(n) + + output.append("\n=== Bytes by device ===") + for dev, b in totals.items(): + output.append(f" {dev:10s} : {pretty_bytes(b)}") + + output.append("\n=== Bytes by dtype ===") + for dt, b in by_dtype.items(): + output.append(f" {str(dt):12s} : {pretty_bytes(b)}") + + if on_meta: + output.append(f"\n⚠️ WARNING: {len(on_meta)} parameters on META (not really loaded). Examples:") + for n in on_meta[:10]: + output.append(f" - {n}") + if len(on_meta) > 10: + output.append(f" ... and {len(on_meta) - 10} more") + + if torch.cuda.is_available(): + free, total = torch.cuda.mem_get_info() + used = total - free + output.append(f"\n=== CUDA Memory ===") + output.append(f" Used: {pretty_bytes(used)} / Total: {pretty_bytes(total)} on cuda:0") + output.append(f" Free: {pretty_bytes(free)} ({(free/total)*100:.1f}%)") + else: + output.append("\n❌ CUDA not available.") + + # Quick check if fully on GPU + all_cuda = all(str(p.device).startswith("cuda") for _, p in model.named_parameters()) + no_meta = not any(str(p.device) == "meta" for _, p in model.named_parameters()) + + output.append(f"\n=== Summary ===") + if all_cuda and no_meta: + output.append("✅ All parameters are on CUDA") + else: + output.append("❌ Model is NOT fully on GPU") + if on_meta: + output.append(" - Some parameters are on META device") + if not all_cuda: + output.append(" - Some parameters are on CPU") + + # Clean up model to free memory + del model + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + except Exception as e: + output.append(f"❌ Error inspecting model: {str(e)}") + + return "\n".join(output) + +def inspect_loaded_model(model) -> str: + """Inspect an already loaded model""" + output = [] + + try: + output.append("=== Inspecting Currently Loaded Model ===\n") + + totals = {} + by_dtype = {} + on_meta = [] + + for n, p in model.named_parameters(): + dev = str(p.device) + totals[dev] = totals.get(dev, 0) + p.numel() * p.element_size() + by_dtype[p.dtype] = by_dtype.get(p.dtype, 0) + p.numel() * p.element_size() + if dev == "meta": + on_meta.append(n) + + output.append("=== Bytes by device ===") + for dev, b in totals.items(): + output.append(f" {dev:10s} : {pretty_bytes(b)}") + + output.append("\n=== Bytes by dtype ===") + for dt, b in by_dtype.items(): + output.append(f" {str(dt):12s} : {pretty_bytes(b)}") + + if on_meta: + output.append(f"\n⚠️ WARNING: {len(on_meta)} parameters on META. Examples:") + for n in on_meta[:5]: + output.append(f" - {n}") + + if torch.cuda.is_available(): + free, total = torch.cuda.mem_get_info() + used = total - free + output.append(f"\n=== CUDA Memory ===") + output.append(f" Used: {pretty_bytes(used)} / Total: {pretty_bytes(total)}") + output.append(f" Free: {pretty_bytes(free)} ({(free/total)*100:.1f}%)") + + # Quick check + all_cuda = all(str(p.device).startswith("cuda") for _, p in model.named_parameters()) + no_meta = not any(str(p.device) == "meta" for _, p in model.named_parameters()) + + output.append(f"\n=== Summary ===") + if all_cuda and no_meta: + output.append("✅ All parameters are on CUDA") + else: + output.append("❌ Model is NOT fully on GPU") + + except Exception as e: + output.append(f"❌ Error: {str(e)}") + + return "\n".join(output) + +if __name__ == "__main__": + import sys + if len(sys.argv) > 1: + model_path = sys.argv[1] + result = inspect_model_devices(model_path) + print(result) + else: + print("Usage: python inspect_devices.py ") \ No newline at end of file diff --git a/tools/test_gptq.py b/tools/test_gptq.py new file mode 100644 index 0000000..331410b --- /dev/null +++ b/tools/test_gptq.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Test GPTQ model loader (requires auto-gptq package) +""" +import os +import traceback + +def test_gptq_loader(model_id="TheBloke/Llama-2-7B-Chat-GPTQ"): + """Test loading GPTQ model with integer device""" + print(f"=== Testing GPTQ Loader: {model_id} ===") + + try: + from auto_gptq import AutoGPTQForCausalLM + from transformers import AutoTokenizer + + # Load tokenizer + print("Loading tokenizer...") + tok = AutoTokenizer.from_pretrained( + model_id, + use_fast=True, + token=os.getenv("HF_TOKEN") + ) + print("✓ Tokenizer loaded") + + # Load GPTQ model with integer device + print("Loading GPTQ model...") + model = AutoGPTQForCausalLM.from_quantized( + model_id, + device=0, # integer index, not "cuda:0" + use_safetensors=True, + trust_remote_code=True, + token=os.getenv("HF_TOKEN") + ) + print("✓ GPTQ model loaded on GPU 0") + + # Test generation + prompt = "The benefits of GPU inference are" + inputs = tok(prompt, return_tensors="pt").to("cuda") + + print("Testing generation...") + outputs = model.generate( + **inputs, + max_new_tokens=16, + do_sample=False, + pad_token_id=tok.eos_token_id + ) + + result = tok.decode(outputs[0], skip_special_tokens=True) + print(f"✓ Generation test passed:") + print(f"Output: {result}") + + return True + + except ImportError as e: + print(f"✗ auto-gptq not available: {e}") + print("Install with: pip install auto-gptq") + return False + except Exception as e: + print(f"✗ Error: {type(e).__name__} - {e}") + traceback.print_exc() + return False + +if __name__ == "__main__": + import torch + print("torch:", torch.__version__) + print("cuda available:", torch.cuda.is_available()) + print() + + test_gptq_loader() \ No newline at end of file diff --git a/tools/test_meta_fp16.py b/tools/test_meta_fp16.py new file mode 100644 index 0000000..ee39cad --- /dev/null +++ b/tools/test_meta_fp16.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Minimal loader test for Meta's FP16/BF16 Llama models (no GPTQ) +""" +import os +import traceback +from transformers import AutoTokenizer, AutoModelForCausalLM + +def test_meta_fp16_loader(): + """Test loading Meta's Llama model with device_map=auto""" + MODEL_ID = "meta-llama/Meta-Llama-3.1-8B" # Meta's repo + + print(f"=== Testing Meta FP16 Loader: {MODEL_ID} ===") + + try: + # Load tokenizer + print("Loading tokenizer...") + tok = AutoTokenizer.from_pretrained( + MODEL_ID, + use_fast=True, + token=os.getenv("HF_TOKEN") + ) + print("✓ Tokenizer loaded") + + # Load model with device_map="auto" + print("Loading model...") + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype="auto", + device_map="auto", # let accelerate place it + trust_remote_code=True, + token=os.getenv("HF_TOKEN") + ) + + # Check device placement + device = next(model.parameters()).device + print(f"✓ Model loaded on device: {device}") + + # Test generation + prompt = "The benefits of GPU inference are" + inputs = tok(prompt, return_tensors="pt").to(device) + + print("Testing generation...") + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=16, + do_sample=False, + pad_token_id=tok.eos_token_id + ) + + result = tok.decode(outputs[0], skip_special_tokens=True) + print(f"✓ Generation test passed:") + print(f"Output: {result}") + + return True + + except Exception as e: + print(f"✗ Error: {type(e).__name__} - {e}") + traceback.print_exc() + return False + +if __name__ == "__main__": + import torch + print("torch:", torch.__version__) + print("cuda available:", torch.cuda.is_available()) + print() + + test_meta_fp16_loader() \ No newline at end of file