# Source file: dark_hal/llm_runtime/chat_session.py
# Snapshot taken: 2026-03-13 12:56:43 -07:00
# Original size: 268 lines, 11 KiB (Python)

"""
Chat Session Management for KV Caching
This module provides chat session management with persistent KV cache
for efficient multi-turn conversations across different model formats.
"""
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import threading
@dataclass
class ChatSessionConfig:
    """Configuration for chat sessions.

    All fields have defaults, so ``ChatSessionConfig()`` yields a usable
    configuration.
    """

    # Maximum context length; surfaced as "max_context" by
    # ChatSession.get_cache_info().
    max_context_length: int = 4096
    # Pre-fill system prompt
    cache_warmup: bool = True
    # Streaming flag; not read anywhere in this module, so presumably
    # consumed by the generation code -- TODO confirm against callers.
    streaming: bool = True
class ChatSession:
    """
    Manages persistent KV cache state for multi-turn conversations.

    Supports different model formats:
    - GGUF models: Rely on llama-cpp-python's built-in caching
    - Transformers models: Manual KV cache management with past_key_values

    Methods that read or mutate the session state do so while holding
    ``self.lock`` (a re-entrant lock), so a method may call another
    lock-taking method of the same session without deadlocking.
    """

    # Markers that denote the end of a turn in the supported prompt formats.
    _EOT_MARKERS = ("<|eot_id|>", "<|end_of_turn|>", "</s>", "<|endoftext|>")

    # Prefixes accepted as the start of a new turn after an end-of-turn marker.
    _CONTINUATION_PREFIXES = ("User:", "Human:")

    # Trailing characters treated as the end of a complete response when the
    # prompt carries no explicit end-of-turn marker.
    _SENTENCE_ENDINGS = (".", "!", "?", ":", ";")

    # Substrings treated as evidence of a corrupted conversation. Order is
    # significant only for which pattern gets reported in the log line.
    _CORRUPTION_PATTERNS = (
        "Assistant: Assistant:",  # Duplicate assistant labels
        "User: User:",  # Duplicate user labels
        "Assistant: User",  # Role confusion
        "User: Assistant:",  # Role confusion
        "of of of",  # Repetitive token generation (sign of corruption)
        "151 of 151",  # Specific corruption pattern we observed
    )

    def __init__(self, session_id: str, config: Optional["ChatSessionConfig"] = None):
        """Create an empty session.

        Args:
            session_id: Identifier reported back by ``get_cache_info``.
            config: Session configuration; a default ``ChatSessionConfig``
                is created when omitted.
        """
        self.session_id = session_id
        self.config = config if config is not None else ChatSessionConfig()
        self.lock = threading.RLock()

        # KV cache state (transformers models)
        self.past_key_values: Optional[Tuple] = None
        self.cached_input_ids: Optional[List[int]] = None
        self.context_length: int = 0

        # Conversation history
        self.messages: List[Dict[str, str]] = []
        self.system_prompt: Optional[str] = None

        # State tracking
        self.is_prefilled: bool = False
        self.last_prompt_hash: Optional[str] = None
        self.last_prompt: Optional[str] = None

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the conversation history.

        A ``"system"`` message is remembered as the session's system prompt
        only when it is the very first message of the history.
        """
        with self.lock:
            if role == "system" and not self.messages:
                self.system_prompt = content
            self.messages.append({"role": role, "content": content})

    def clear_messages(self) -> None:
        """Clear conversation history but preserve system prompt.

        The KV cache is invalidated as well, since it no longer corresponds
        to the (now truncated) history.
        """
        with self.lock:
            if self.system_prompt:
                self.messages = [{"role": "system", "content": self.system_prompt}]
            else:
                self.messages = []
            self.invalidate_cache()

    def invalidate_cache(self) -> None:
        """Invalidate KV cache - forces re-prefill on next generation."""
        with self.lock:
            self.past_key_values = None
            self.cached_input_ids = None
            self.context_length = 0
            self.is_prefilled = False
            self.last_prompt_hash = None
            self.last_prompt = None

    def should_invalidate(self, new_prompt: str, tokenizer=None) -> bool:
        """Decide whether the cached state can be reused for ``new_prompt``.

        The cache is considered reusable only when the previously cached
        prompt is an exact token-level prefix of ``new_prompt`` and no turn
        boundary violation is detected.

        Args:
            new_prompt: The full prompt about to be generated from.
            tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.
                Without it no token-level validation is possible, so the
                answer is always ``True``.

        Returns:
            ``True`` if the cache must be discarded, ``False`` if it is safe
            to reuse.
        """
        # Snapshot the shared state under the lock so the decision below is
        # made against one consistent view; tokenization runs unlocked.
        with self.lock:
            has_cache = self.past_key_values is not None
            cached_prompt = getattr(self, "last_prompt", None)

        # Nothing cached, nothing to validate with, or nothing to compare to.
        if not has_cache or tokenizer is None or cached_prompt is None:
            return True

        if not self._has_valid_token_prefix(new_prompt, cached_prompt, tokenizer):
            print("[KV_CACHE] Invalidating cache: no valid token prefix match")
            return True

        # Additional safety checks for turn boundaries
        if self._violates_turn_boundaries(new_prompt, cached_prompt):
            print("[KV_CACHE] Invalidating cache: turn boundary violation detected")
            return True

        return False

    def _has_valid_token_prefix(self, new_prompt: str, old_prompt: str, tokenizer) -> bool:
        """Check if new prompt has exact token-level prefix match with cached prompt.

        Returns ``False`` when the new prompt tokenizes to fewer tokens than
        the old one, when any token in the shared prefix differs, or when
        the tokenizer raises.
        """
        try:
            old_tokens = tokenizer.encode(old_prompt, add_special_tokens=False)
            new_tokens = tokenizer.encode(new_prompt, add_special_tokens=False)

            # New prompt must be longer than or equal to old
            if len(new_tokens) < len(old_tokens):
                return False

            # Token-by-token comparison (rather than a slice compare) so the
            # first mismatching position can be reported.
            for i, (old_token, new_token) in enumerate(zip(old_tokens, new_tokens)):
                if old_token != new_token:
                    print(f"[KV_CACHE] Token mismatch at position {i}: old={old_token}, new={new_token}")
                    return False
            return True
        except Exception as e:
            # Deliberately broad: the tokenizer is caller-supplied, and any
            # failure here just means the cache cannot be trusted.
            print(f"[KV_CACHE] Error in token prefix validation: {e}")
            return False

    def _violates_turn_boundaries(self, new_prompt: str, old_prompt: str) -> bool:
        """Check if the prompt violates turn boundaries (conversation integrity).

        If the cached prompt ended on an end-of-turn marker, the text that
        follows it in ``new_prompt`` must open a new user turn. Independently
        of that, ``new_prompt`` must not contain a known corruption pattern.
        """
        if old_prompt.rstrip().endswith(self._EOT_MARKERS):
            # Leading whitespace (including newlines) is stripped before the
            # prefix test, so "\nUser:" and "User:" are treated alike.
            continuation_part = new_prompt[len(old_prompt):].lstrip()
            if not continuation_part.startswith(self._CONTINUATION_PREFIXES):
                return True

        # Check for conversation format corruption (duplicate roles, malformed structure)
        return self._has_conversation_corruption(new_prompt)

    def _has_conversation_corruption(self, prompt: str) -> bool:
        """Detect conversation format corruption that indicates cache should be invalidated.

        This is a plain substring scan over the whole prompt; the first
        matching pattern is logged.
        """
        for pattern in self._CORRUPTION_PATTERNS:
            if pattern in prompt:
                print(f"[KV_CACHE] Detected conversation corruption: '{pattern}'")
                return True
        return False

    def update_cache(self, past_key_values: Tuple, input_ids: List[int], prompt: str) -> None:
        """Update the KV cache state with turn boundary validation.

        Args:
            past_key_values: Cache object to keep for the next generation.
            input_ids: Token ids the cache was computed for; their count
                becomes ``context_length``.
            prompt: The prompt text matching ``input_ids``; stored for later
                prefix comparison in ``should_invalidate``.

        An incomplete turn only produces a warning; the cache is stored
        regardless.
        """
        import hashlib

        with self.lock:
            # Validate that we're ending on a complete turn boundary
            if not self._is_complete_turn(prompt):
                print("[KV_CACHE] Warning: Caching incomplete turn - this may cause issues")

            self.past_key_values = past_key_values
            self.cached_input_ids = input_ids
            self.context_length = len(input_ids)
            self.is_prefilled = True
            self.last_prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
            self.last_prompt = prompt  # Store the actual prompt for comparison

    def _is_complete_turn(self, prompt: str) -> bool:
        """Check if the prompt ends on a complete turn boundary.

        Trailing whitespace is ignored. A prompt counts as complete when it
        ends with an end-of-turn marker, with sentence-final punctuation, or
        with ``"Assistant:"`` (ready for generation).
        """
        prompt_trimmed = prompt.rstrip()

        # Explicit end-of-turn marker.
        if prompt_trimmed.endswith(self._EOT_MARKERS):
            return True

        # For non-marked conversations, check if it looks like a complete response
        # (ends with punctuation and doesn't look cut off)
        if prompt_trimmed.endswith(self._SENTENCE_ENDINGS):
            return True

        # If it ends with "Assistant:" it's ready for generation
        return prompt_trimmed.endswith("Assistant:")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about current cache state.

        Returns:
            A dict with the keys ``session_id``, ``has_cache``,
            ``context_length``, ``is_prefilled``, ``message_count`` and
            ``max_context``.
        """
        with self.lock:
            return {
                "session_id": self.session_id,
                "has_cache": self.past_key_values is not None,
                "context_length": self.context_length,
                "is_prefilled": self.is_prefilled,
                "message_count": len(self.messages),
                "max_context": self.config.max_context_length,
            }
class ChatSessionManager:
    """Manages multiple chat sessions.

    Sessions are kept in ``self.sessions`` keyed by session id. Every
    operation on that mapping is performed while holding ``self.lock``.
    """

    def __init__(self):
        """Create an empty manager whose default session id is ``"default"``."""
        self.sessions: Dict[str, "ChatSession"] = {}
        self.lock = threading.RLock()
        self.default_session_id = "default"

    def get_session(
        self,
        session_id: Optional[str] = None,
        config: Optional["ChatSessionConfig"] = None,
    ) -> "ChatSession":
        """Get or create a chat session.

        Args:
            session_id: Session to look up; the default session id is used
                when omitted.
            config: Passed to ``ChatSession`` only when a new session has to
                be created. It is ignored for a session that already exists.

        Returns:
            The existing or newly created session.
        """
        if session_id is None:
            session_id = self.default_session_id
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                session = ChatSession(session_id, config)
                self.sessions[session_id] = session
            return session

    def clear_session(self, session_id: Optional[str] = None) -> None:
        """Clear a specific session.

        Drops the session's KV cache and clears its messages. The session
        itself stays registered. Unknown ids are ignored.
        """
        if session_id is None:
            session_id = self.default_session_id
        with self.lock:
            session = self.sessions.get(session_id)
            if session is not None:
                session.invalidate_cache()
                session.clear_messages()

    def remove_session(self, session_id: str) -> None:
        """Remove a session completely. Unknown ids are ignored."""
        with self.lock:
            self.sessions.pop(session_id, None)

    def get_all_sessions(self) -> List[str]:
        """Get list of all session IDs (a copy, safe to iterate while sessions change)."""
        with self.lock:
            return list(self.sessions)
# Global session manager instance, created once when this module is imported.
_session_manager = ChatSessionManager()


def get_session_manager() -> ChatSessionManager:
    """Get the global chat session manager.

    Every call returns the same module-level ``ChatSessionManager`` object.
    """
    return _session_manager