Files
dark_hal/chat_templates.py
2026-03-13 12:56:43 -07:00

393 lines
16 KiB
Python

"""
Chat Template Management System
This module provides chat template management for different model formats,
allowing users to apply proper conversation formatting for optimal model performance.
"""
import json
import os
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
import tkinter as tk
from tkinter import ttk, messagebox, simpledialog
@dataclass
class ChatTemplate:
    """Represents a chat template configuration.

    Each role's message content is wrapped as ``<role>_prefix + content +
    <role>_suffix`` when a conversation is formatted. ``stop_tokens`` always
    ends up as a list owned by this instance (``None`` becomes ``[]``).
    """
    name: str
    description: str
    system_prefix: str = ""
    system_suffix: str = ""
    user_prefix: str = ""
    user_suffix: str = ""
    assistant_prefix: str = ""
    assistant_suffix: str = ""
    # Emitted between turns by ChatTemplateManager.format_conversation.
    turn_separator: str = ""
    eos_token: str = ""
    # Prepended once at the start of a formatted conversation, if non-empty.
    bos_token: str = ""
    # None is accepted for convenience and normalised in __post_init__.
    stop_tokens: Optional[List[str]] = None
    # Whether the assistant prefix is appended as a generation prompt.
    add_generation_prompt: bool = True

    def __post_init__(self):
        """Normalise ``stop_tokens`` to a private list.

        Copying the caller's sequence means later mutation of that list (or
        passing a tuple) cannot leak into this template's state.
        """
        if self.stop_tokens is None:
            self.stop_tokens = []
        else:
            self.stop_tokens = list(self.stop_tokens)
class ChatTemplateManager:
    """Manage chat templates with JSON persistence.

    Templates live in ``self.templates`` keyed by template name, and are
    written back to ``templates_file`` whenever a template is added, updated
    or deleted (and when the built-in default has to be seeded).
    """

    def __init__(self, templates_file: str = "chat_templates.json"):
        """Load templates from *templates_file*, then seed the built-in defaults.

        Args:
            templates_file: Path of the JSON file used for persistence. A
                relative path is opened relative to the current working
                directory.
        """
        self.templates_file = templates_file
        self.templates: Dict[str, "ChatTemplate"] = {}
        self._load_templates()
        self._ensure_default_templates()

    def _load_templates(self):
        """Load templates from the JSON file, if it exists.

        The file is expected to be a JSON object mapping template names to
        ChatTemplate keyword arguments. Unreadable or malformed files are
        reported and ignored. A single malformed entry is skipped without
        discarding the other, valid templates.
        """
        if not os.path.exists(self.templates_file):
            return
        try:
            with open(self.templates_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
            print(f"Error loading chat templates: {e}")
            return
        if not isinstance(data, dict):
            print(f"Error loading chat templates: expected a JSON object, got {type(data).__name__}")
            return
        for name, template_data in data.items():
            try:
                self.templates[name] = ChatTemplate(**template_data)
            except TypeError as e:
                # Missing/unknown fields, or an entry that is not a mapping.
                print(f"Error loading chat template {name!r}: {e}")

    def _save_templates(self):
        """Write all templates to the JSON file (best effort).

        The data is first written to a sibling ``.tmp`` file, which then
        replaces the target via ``os.replace``. An interrupted write therefore
        cannot leave a truncated templates file behind. Errors are printed,
        not raised.
        """
        tmp_path = f"{self.templates_file}.tmp"
        try:
            data = {name: asdict(template) for name, template in self.templates.items()}
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, self.templates_file)
        except Exception as e:
            print(f"Error saving chat templates: {e}")
            # Don't leave a half-written temp file lying around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def _ensure_default_templates(self):
        """Add the built-in Llama 3.1 template if it is missing, and persist it.

        NOTE(review): if the templates file could not be parsed, this save
        overwrites it with only the defaults. Consider backing up an
        unreadable file before it is replaced.
        """
        if "Llama-3.1-Instruct" not in self.templates:
            self.templates["Llama-3.1-Instruct"] = ChatTemplate(
                name="Llama-3.1-Instruct",
                description="Official Llama 3.1 Instruct chat template with proper headers and EOT tokens",
                bos_token="<|begin_of_text|>",
                system_prefix="<|start_header_id|>system<|end_header_id|>\n\n",
                system_suffix="<|eot_id|>",
                user_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
                user_suffix="<|eot_id|>",
                assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
                assistant_suffix="",  # Model generates until <|eot_id|>
                eos_token="<|eot_id|>",
                stop_tokens=["<|eot_id|>"],
                add_generation_prompt=True
            )
            self._save_templates()

    def get_template_names(self) -> List[str]:
        """Return the names of all known templates, in insertion order."""
        return list(self.templates.keys())

    def get_template(self, name: str) -> Optional["ChatTemplate"]:
        """Return the template called *name*, or ``None`` if it is unknown."""
        return self.templates.get(name)

    def add_template(self, template: "ChatTemplate") -> bool:
        """Add a new template and persist it.

        Returns:
            ``True`` if added. ``False`` (without changes) if a template with
            the same name already exists.
        """
        if template.name in self.templates:
            return False  # Template already exists
        self.templates[template.name] = template
        self._save_templates()
        return True

    def update_template(self, template: "ChatTemplate") -> None:
        """Insert or overwrite the template stored under ``template.name`` and persist."""
        self.templates[template.name] = template
        self._save_templates()

    def delete_template(self, name: str) -> bool:
        """Delete the template called *name* and persist.

        Returns:
            ``True`` if a template was removed, ``False`` if none matched.
        """
        if name in self.templates:
            del self.templates[name]
            self._save_templates()
            return True
        return False

    def format_conversation(self, template_name: str, messages: List[Dict[str, str]],
                            add_generation_prompt: bool = True) -> str:
        """Render *messages* into a prompt string using a named template.

        Args:
            template_name: Name of the template to use. Unknown names fall back
                to plain ``User:``/``Assistant:`` formatting.
            messages: Dicts with ``"role"`` (``system``/``user``/``assistant``,
                defaulting to ``user``) and ``"content"`` (default ``""``).
                Messages with any other role are not rendered.
            add_generation_prompt: Append the assistant prefix at the end so
                the model continues as the assistant. Only honoured when the
                template's own ``add_generation_prompt`` is also true.

        Returns:
            The formatted prompt.
        """
        template = self.get_template(template_name)
        if template is None:
            # Fallback to simple User:/Assistant: format
            return self._format_simple(messages, add_generation_prompt)
        result = []
        # Add BOS token if specified
        if template.bos_token:
            result.append(template.bos_token)
        # No default system prompt is injected; callers supply one as a message.
        last_index = len(messages) - 1
        for index, message in enumerate(messages):
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                if template.system_prefix or template.system_suffix:
                    result.append(f"{template.system_prefix}{content}{template.system_suffix}")
                else:
                    result.append(content)
            elif role == "user":
                result.append(f"{template.user_prefix}{content}{template.user_suffix}")
            elif role == "assistant":
                result.append(f"{template.assistant_prefix}{content}{template.assistant_suffix}")
            # Separator between turns, but not after the last message when a
            # generation prompt is requested. Compare by position: an earlier
            # message may be equal in value to the last one.
            if template.turn_separator and not (add_generation_prompt and index == last_index):
                result.append(template.turn_separator)
        # Add generation prompt for assistant response
        if add_generation_prompt and template.add_generation_prompt:
            result.append(template.assistant_prefix)
        return "".join(result)

    def _format_simple(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
        """Fallback formatting: one ``Role: content`` line per message.

        Messages with an unrecognised role are skipped. When
        *add_generation_prompt* is true, a trailing ``Assistant:`` line is added.
        """
        result = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                result.append(f"System: {content}")
            elif role == "user":
                result.append(f"User: {content}")
            elif role == "assistant":
                result.append(f"Assistant: {content}")
        if add_generation_prompt:
            result.append("Assistant:")
        return "\n".join(result)

    def get_stop_tokens(self, template_name: str) -> List[str]:
        """Return a copy of the stop tokens for *template_name* (``[]`` if unknown).

        A copy is returned so callers cannot mutate the stored template.
        """
        template = self.get_template(template_name)
        if template and template.stop_tokens:
            return list(template.stop_tokens)
        return []
class ChatTemplateDialog:
    """Modal Toplevel dialog for creating or editing a ChatTemplate.

    When the user presses Save with a non-empty name, ``self.result`` is set
    to the ChatTemplate built from the form and the dialog is destroyed.
    Otherwise ``self.result`` stays ``None``. The constructor does not block.
    NOTE(review): callers presumably wait on ``self.dialog`` (e.g.
    ``parent.wait_window(dialog.dialog)``) before reading ``result``; confirm.
    """

    def __init__(self, parent: tk.Tk, template: Optional["ChatTemplate"] = None):
        """Build the dialog and pre-fill it from *template* if one is given."""
        self.parent = parent
        self.template = template
        self.result = None
        # Create dialog
        self.dialog = tk.Toplevel(parent)
        self.dialog.title("Chat Template Editor")
        self.dialog.geometry("600x700")
        self.dialog.resizable(True, True)
        # Make dialog modal
        self.dialog.transient(parent)
        self.dialog.grab_set()
        self._build_ui()
        if template:
            self._load_template_data()
        self._center_window()

    def _build_ui(self):
        """Build the template editor UI: info fields, components, preview, buttons."""
        # Main frame (no scrollbar; content is laid out with pack/grid)
        main_frame = ttk.Frame(self.dialog)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        # Basic Info
        info_frame = ttk.LabelFrame(main_frame, text="Template Information", padding=10)
        info_frame.pack(fill=tk.X, pady=(0, 10))
        ttk.Label(info_frame, text="Name:").grid(row=0, column=0, sticky=tk.W, pady=2)
        self.name_var = tk.StringVar()
        ttk.Entry(info_frame, textvariable=self.name_var, width=40).grid(row=0, column=1, sticky=tk.EW, pady=2)
        ttk.Label(info_frame, text="Description:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.desc_var = tk.StringVar()
        ttk.Entry(info_frame, textvariable=self.desc_var, width=40).grid(row=1, column=1, sticky=tk.EW, pady=2)
        info_frame.grid_columnconfigure(1, weight=1)
        # Template Components
        components_frame = ttk.LabelFrame(main_frame, text="Template Components", padding=10)
        components_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
        # One entry per string field of ChatTemplate, keyed by attribute name.
        self.template_vars = {}
        components = [
            ("BOS Token", "bos_token"),
            ("System Prefix", "system_prefix"),
            ("System Suffix", "system_suffix"),
            ("User Prefix", "user_prefix"),
            ("User Suffix", "user_suffix"),
            ("Assistant Prefix", "assistant_prefix"),
            ("Assistant Suffix", "assistant_suffix"),
            ("Turn Separator", "turn_separator"),
            ("EOS Token", "eos_token"),
        ]
        for i, (label, key) in enumerate(components):
            ttk.Label(components_frame, text=f"{label}:").grid(row=i, column=0, sticky=tk.W, pady=2)
            var = tk.StringVar()
            entry = ttk.Entry(components_frame, textvariable=var, width=50)
            entry.grid(row=i, column=1, sticky=tk.EW, pady=2, padx=(5, 0))
            self.template_vars[key] = var
        # Stop tokens (text area)
        ttk.Label(components_frame, text="Stop Tokens (one per line):").grid(row=len(components), column=0, sticky=tk.W, pady=2)
        self.stop_tokens_text = tk.Text(components_frame, height=4, width=50)
        self.stop_tokens_text.grid(row=len(components), column=1, sticky=tk.EW, pady=2, padx=(5, 0))
        # Add generation prompt checkbox
        self.add_gen_prompt_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(components_frame, text="Add generation prompt",
                        variable=self.add_gen_prompt_var).grid(row=len(components) + 1, column=0, columnspan=2, sticky=tk.W, pady=5)
        components_frame.grid_columnconfigure(1, weight=1)
        # Preview
        preview_frame = ttk.LabelFrame(main_frame, text="Preview", padding=10)
        preview_frame.pack(fill=tk.X, pady=(0, 10))
        ttk.Button(preview_frame, text="Generate Preview", command=self._generate_preview).pack(side=tk.LEFT)
        self.preview_text = tk.Text(preview_frame, height=6, wrap=tk.WORD)
        self.preview_text.pack(fill=tk.BOTH, expand=True, pady=(10, 0))
        # Buttons
        button_frame = ttk.Frame(main_frame)
        button_frame.pack(fill=tk.X, pady=(10, 0))
        ttk.Button(button_frame, text="Save", command=self._save_template).pack(side=tk.RIGHT, padx=5)
        ttk.Button(button_frame, text="Cancel", command=self.dialog.destroy).pack(side=tk.RIGHT)

    def _load_template_data(self):
        """Copy the fields of ``self.template`` into the form widgets."""
        if not self.template:
            return
        self.name_var.set(self.template.name)
        self.desc_var.set(self.template.description)
        for key, var in self.template_vars.items():
            value = getattr(self.template, key, "")
            var.set(value)
        # Load stop tokens
        if self.template.stop_tokens:
            self.stop_tokens_text.insert('1.0', '\n'.join(self.template.stop_tokens))
        self.add_gen_prompt_var.set(self.template.add_generation_prompt)

    def _generate_preview(self):
        """Render a fixed sample conversation with the current form values."""
        try:
            # Create a temporary template from current form data
            template = self._create_template_from_form()
            # Sample conversation
            sample_messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you! How can I help you today?"},
                {"role": "user", "content": "Can you explain quantum physics?"}
            ]
            # Use a throwaway manager holding only the preview template.
            # Bypassing __init__ avoids loading (and possibly re-saving) the
            # templates JSON file on every preview click;
            # format_conversation only needs ``templates``.
            temp_manager = ChatTemplateManager.__new__(ChatTemplateManager)
            temp_manager.templates = {"preview": template}
            formatted = temp_manager.format_conversation("preview", sample_messages, True)
            self.preview_text.delete('1.0', tk.END)
            self.preview_text.insert('1.0', formatted)
        except Exception as e:
            self.preview_text.delete('1.0', tk.END)
            self.preview_text.insert('1.0', f"Error generating preview: {e}")

    def _create_template_from_form(self) -> "ChatTemplate":
        """Build a ChatTemplate from the form.

        Name and description are stripped. Stop tokens are read one per line,
        and each line is stripped; blank lines are dropped.
        """
        # Get stop tokens
        stop_tokens_text = self.stop_tokens_text.get('1.0', tk.END).strip()
        stop_tokens = [token.strip() for token in stop_tokens_text.split('\n') if token.strip()]
        return ChatTemplate(
            name=self.name_var.get().strip(),
            description=self.desc_var.get().strip(),
            bos_token=self.template_vars["bos_token"].get(),
            system_prefix=self.template_vars["system_prefix"].get(),
            system_suffix=self.template_vars["system_suffix"].get(),
            user_prefix=self.template_vars["user_prefix"].get(),
            user_suffix=self.template_vars["user_suffix"].get(),
            assistant_prefix=self.template_vars["assistant_prefix"].get(),
            assistant_suffix=self.template_vars["assistant_suffix"].get(),
            turn_separator=self.template_vars["turn_separator"].get(),
            eos_token=self.template_vars["eos_token"].get(),
            stop_tokens=stop_tokens,
            add_generation_prompt=self.add_gen_prompt_var.get()
        )

    def _save_template(self):
        """Validate the form, store it in ``self.result`` and close the dialog."""
        try:
            template = self._create_template_from_form()
            # Validate required fields
            if not template.name:
                messagebox.showerror("Error", "Template name is required")
                return
            self.result = template
            self.dialog.destroy()
        except Exception as e:
            messagebox.showerror("Error", f"Error saving template: {e}")

    def _center_window(self):
        """Center the dialog over the parent window, keeping it on-screen."""
        # Flush pending geometry so winfo_* report real sizes.
        self.dialog.update_idletasks()
        # Get parent position
        parent_x = self.parent.winfo_x()
        parent_y = self.parent.winfo_y()
        parent_width = self.parent.winfo_width()
        parent_height = self.parent.winfo_height()
        # Get dialog size
        dialog_width = self.dialog.winfo_width()
        dialog_height = self.dialog.winfo_height()
        # Clamp to 0 so a dialog larger than its parent is not pushed
        # off the top/left edge of the screen.
        x = max(0, parent_x + (parent_width - dialog_width) // 2)
        y = max(0, parent_y + (parent_height - dialog_height) // 2)
        self.dialog.geometry(f"+{x}+{y}")
# Module-level singleton, populated lazily by get_template_manager().
_template_manager = None


def get_template_manager() -> ChatTemplateManager:
    """Return the shared ChatTemplateManager, creating it on the first call.

    The first call constructs the manager with its default templates file,
    which loads that file and may write the default template into it.
    Later calls return the same instance.
    """
    global _template_manager
    manager = _template_manager
    if manager is not None:
        return manager
    manager = ChatTemplateManager()
    _template_manager = manager
    return manager