first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

37
llm_runtime/__init__.py Normal file
View File

@@ -0,0 +1,37 @@
from .registry import load_model as _load_model
from .types import UnifiedModel, GenerateConfig
# Announce the loaded model to the global spy so other parts of the app can access it
def load_model(*args, **kwargs):
"""
Proxy to the real registry.load_model that also announces the loaded model via __spy.
Accepts arbitrary args/kwargs to remain compatible with all loaders.
"""
model = _load_model(*args, **kwargs)
try:
import __spy as spy # local import to avoid hard dependency during tooling
# Best-effort extraction of a model name from common argument patterns
model_name = (
kwargs.get("source")
or kwargs.get("model")
or (args[0] if args else None)
or getattr(model, "name", None)
or "unknown"
)
# Shallow capture of load parameters (omit non-serializable)
safe_params = {}
for k, v in kwargs.items():
try:
repr(v) # ensure it is representable
safe_params[k] = v
except Exception:
continue
spy.set_model(str(model_name), model, **safe_params)
except Exception:
# Never let announcing break model loading
pass
return model
__all__ = ["load_model", "UnifiedModel", "GenerateConfig"]

268
llm_runtime/chat_session.py Normal file
View File

@@ -0,0 +1,268 @@
"""
Chat Session Management for KV Caching
This module provides chat session management with persistent KV cache
for efficient multi-turn conversations across different model formats.
"""
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import threading
@dataclass
class ChatSessionConfig:
"""Configuration for chat sessions"""
max_context_length: int = 4096
cache_warmup: bool = True # Pre-fill system prompt
streaming: bool = True
class ChatSession:
"""
Manages persistent KV cache state for multi-turn conversations.
Supports different model formats:
- GGUF models: Rely on llama-cpp-python's built-in caching
- Transformers models: Manual KV cache management with past_key_values
"""
def __init__(self, session_id: str, config: ChatSessionConfig = None):
self.session_id = session_id
self.config = config or ChatSessionConfig()
self.lock = threading.RLock()
# KV cache state (transformers models)
self.past_key_values: Optional[Tuple] = None
self.cached_input_ids: Optional[List[int]] = None
self.context_length: int = 0
# Conversation history
self.messages: List[Dict[str, str]] = []
self.system_prompt: Optional[str] = None
# State tracking
self.is_prefilled: bool = False
self.last_prompt_hash: Optional[str] = None
self.last_prompt: Optional[str] = None
def add_message(self, role: str, content: str) -> None:
"""Add a message to the conversation history"""
with self.lock:
if role == "system" and not self.messages:
self.system_prompt = content
self.messages.append({"role": role, "content": content})
def clear_messages(self) -> None:
"""Clear conversation history but preserve system prompt"""
with self.lock:
if self.system_prompt:
self.messages = [{"role": "system", "content": self.system_prompt}]
else:
self.messages = []
self.invalidate_cache()
def invalidate_cache(self) -> None:
"""Invalidate KV cache - forces re-prefill on next generation"""
with self.lock:
self.past_key_values = None
self.cached_input_ids = None
self.context_length = 0
self.is_prefilled = False
self.last_prompt_hash = None
self.last_prompt = None
def should_invalidate(self, new_prompt: str, tokenizer=None) -> bool:
"""Check if cache should be invalidated based on prompt changes with proper token-level validation"""
import hashlib
current_hash = hashlib.md5(new_prompt.encode()).hexdigest()
# Always invalidate if no cache exists
if self.past_key_values is None:
return True
# Always invalidate if we don't have a tokenizer for proper validation
if tokenizer is None:
return True
# Always invalidate if no previous prompt exists
if not hasattr(self, 'last_prompt') or self.last_prompt is None:
return True
# Check for exact token-level prefix match
if self._has_valid_token_prefix(new_prompt, self.last_prompt, tokenizer):
# Additional safety checks for turn boundaries
if self._violates_turn_boundaries(new_prompt, self.last_prompt):
print("[KV_CACHE] Invalidating cache: turn boundary violation detected")
return True
return False
# Invalidate if no valid prefix match
print("[KV_CACHE] Invalidating cache: no valid token prefix match")
return True
def _has_valid_token_prefix(self, new_prompt: str, old_prompt: str, tokenizer) -> bool:
"""Check if new prompt has exact token-level prefix match with cached prompt"""
try:
# Tokenize both prompts
old_tokens = tokenizer.encode(old_prompt, add_special_tokens=False)
new_tokens = tokenizer.encode(new_prompt, add_special_tokens=False)
# New prompt must be longer than or equal to old
if len(new_tokens) < len(old_tokens):
return False
# Check exact token-by-token match for the prefix
for i, (old_token, new_token) in enumerate(zip(old_tokens, new_tokens)):
if old_token != new_token:
print(f"[KV_CACHE] Token mismatch at position {i}: old={old_token}, new={new_token}")
return False
return True
except Exception as e:
print(f"[KV_CACHE] Error in token prefix validation: {e}")
return False
def _violates_turn_boundaries(self, new_prompt: str, old_prompt: str) -> bool:
"""Check if the prompt violates turn boundaries (conversation integrity)"""
# Look for end-of-turn markers that indicate conversation corruption
eot_markers = ["<|eot_id|>", "<|end_of_turn|>", "</s>", "<|endoftext|>"]
# If the old prompt ended with an EOT marker, we should start fresh
for marker in eot_markers:
if old_prompt.rstrip().endswith(marker):
# Check if new prompt is a proper continuation (should start with User: or similar)
continuation_part = new_prompt[len(old_prompt):].lstrip()
if not (continuation_part.startswith("User:") or continuation_part.startswith("Human:") or continuation_part.startswith("\nUser:") or continuation_part.startswith("\nHuman:")):
return True
# Check for conversation format corruption (duplicate roles, malformed structure)
if self._has_conversation_corruption(new_prompt):
return True
return False
def _has_conversation_corruption(self, prompt: str) -> bool:
"""Detect conversation format corruption that indicates cache should be invalidated"""
# Look for signs of corrupted conversation format
corruption_patterns = [
"Assistant: Assistant:", # Duplicate assistant labels
"User: User:", # Duplicate user labels
"Assistant: User", # Role confusion
"User: Assistant:", # Role confusion
"of of of", # Repetitive token generation (sign of corruption)
"151 of 151", # Specific corruption pattern we observed
]
for pattern in corruption_patterns:
if pattern in prompt:
print(f"[KV_CACHE] Detected conversation corruption: '{pattern}'")
return True
return False
def update_cache(self, past_key_values: Tuple, input_ids: List[int], prompt: str) -> None:
"""Update the KV cache state with turn boundary validation"""
import hashlib
with self.lock:
# Validate that we're ending on a complete turn boundary
if not self._is_complete_turn(prompt):
print("[KV_CACHE] Warning: Caching incomplete turn - this may cause issues")
self.past_key_values = past_key_values
self.cached_input_ids = input_ids
self.context_length = len(input_ids)
self.is_prefilled = True
self.last_prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
self.last_prompt = prompt # Store the actual prompt for comparison
def _is_complete_turn(self, prompt: str) -> bool:
"""Check if the prompt ends on a complete turn boundary"""
# Look for proper turn endings
prompt_trimmed = prompt.rstrip()
# Should end with Assistant response, not mid-generation
valid_endings = [
"</s>",
"<|eot_id|>",
"<|end_of_turn|>",
"<|endoftext|>",
]
# Or should end with a complete sentence/response
if any(prompt_trimmed.endswith(ending) for ending in valid_endings):
return True
# For non-marked conversations, check if it looks like a complete response
# (ends with punctuation and doesn't look cut off)
if prompt_trimmed.endswith(('.', '!', '?', ':', ';')):
return True
# If it ends with "Assistant:" it's ready for generation
if prompt_trimmed.endswith("Assistant:"):
return True
return False
def get_cache_info(self) -> Dict[str, Any]:
"""Get information about current cache state"""
with self.lock:
return {
"session_id": self.session_id,
"has_cache": self.past_key_values is not None,
"context_length": self.context_length,
"is_prefilled": self.is_prefilled,
"message_count": len(self.messages),
"max_context": self.config.max_context_length
}
class ChatSessionManager:
"""Manages multiple chat sessions"""
def __init__(self):
self.sessions: Dict[str, ChatSession] = {}
self.lock = threading.RLock()
self.default_session_id = "default"
def get_session(self, session_id: str = None, config: ChatSessionConfig = None) -> ChatSession:
"""Get or create a chat session"""
if session_id is None:
session_id = self.default_session_id
with self.lock:
if session_id not in self.sessions:
self.sessions[session_id] = ChatSession(session_id, config)
return self.sessions[session_id]
def clear_session(self, session_id: str = None) -> None:
"""Clear a specific session"""
if session_id is None:
session_id = self.default_session_id
with self.lock:
if session_id in self.sessions:
self.sessions[session_id].invalidate_cache()
self.sessions[session_id].clear_messages()
def remove_session(self, session_id: str) -> None:
"""Remove a session completely"""
with self.lock:
if session_id in self.sessions:
del self.sessions[session_id]
def get_all_sessions(self) -> List[str]:
"""Get list of all session IDs"""
with self.lock:
return list(self.sessions.keys())
# Global session manager instance
_session_manager = ChatSessionManager()
def get_session_manager() -> ChatSessionManager:
"""Get the global chat session manager"""
return _session_manager

View File

@@ -0,0 +1,99 @@
import torch
from typing import Union, Literal
Backend = Literal["hf", "gptq"]
DevIn = Union[None, str, int]
DevOut = Union[str, int]
def _has_mps() -> bool:
return getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
def _first_cuda_index() -> int | None:
return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None
def normalize_device(dev: DevIn = None, *, backend: Backend = "hf") -> DevOut:
"""
Normalize a user/device string/int into what each backend expects.
Inputs accepted:
None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int)
Returns:
backend == "hf": "cpu" | "mps" | "cuda:N"
backend == "gptq": "cpu" | "mps" | "disk" | N (int)
"""
# 1) Auto/default
if dev in (None, "auto"):
cuda0 = _first_cuda_index()
if cuda0 is not None:
return (cuda0 if backend == "gptq" else f"cuda:{cuda0}")
if _has_mps():
return "mps"
# AutoGPTQ can also run with 'disk' offload if caller wants; default to CPU here
return "cpu"
# 2) Explicit CPU/MPS/DISK
if isinstance(dev, str) and dev.lower() in {"cpu", "mps", "disk"}:
# HF does not know "disk"; treat as CPU for HF branch
return dev if backend == "gptq" or dev != "disk" else "cpu"
# 3) Explicit CUDA string
if isinstance(dev, str) and dev.lower().startswith("cuda"):
# Accept "cuda" and "cuda:N"
if dev == "cuda":
idx = _first_cuda_index()
if idx is None:
# No CUDA available; degrade to CPU/MPS appropriately
return "cpu" if backend == "hf" else "cpu"
return (idx if backend == "gptq" else f"cuda:{idx}")
# cuda:N
try:
idx = int(dev.split(":", 1)[1])
except (IndexError, ValueError):
raise ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.")
return (idx if backend == "gptq" else f"cuda:{idx}")
# 4) Integer GPU index
if isinstance(dev, int):
if dev < 0:
raise ValueError(f"GPU index must be >= 0, got {dev}")
return (dev if backend == "gptq" else f"cuda:{dev}")
raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}")
# --- Convenience wrappers -------------------------------------------------------
def device_for_hf(dev: DevIn = None) -> str:
"""Return a device string suitable for HuggingFace (e.g., 'cuda:0', 'cpu', 'mps')."""
out = normalize_device(dev, backend="hf")
assert isinstance(out, str)
return out
def device_for_gptq(dev: DevIn = None) -> Union[int, str]:
"""Return an int GPU index or 'cpu'/'mps'/'disk' for AutoGPTQ."""
out = normalize_device(dev, backend="gptq")
assert isinstance(out, (int, str))
return out
def debug_device_placement(model, name="model"):
"""Debug helper to check where model parameters are placed"""
try:
devices = set()
for name_param, param in model.named_parameters():
devices.add(str(param.device))
print(f"[DEBUG] {name} parameters on devices: {devices}")
# Check first parameter device
first_param = next(model.parameters())
print(f"[DEBUG] {name} primary device: {first_param.device}")
return first_param.device
except Exception as e:
print(f"[DEBUG] Could not check {name} device placement: {e}")
return None
# --- Minimal self-test (run this file directly) ---------------------------------
if __name__ == "__main__":
tests = [None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1]
for t in tests:
print(f"in={t!r:7} -> hf={device_for_hf(t)!r:7} gptq={device_for_gptq(t)!r}")

View File

@@ -0,0 +1,25 @@
# loader_factory.py
from typing import Any, Tuple
from .model_router import detect_loader_type, LoaderKind
def select_loader(source: str) -> Tuple[LoaderKind, str]:
return detect_loader_type(source)
def load_model_for_gui(source: str, **kwargs: Any):
kind, reason = detect_loader_type(source)
print(f"[ROUTER] Using {kind.upper()} loader: {reason}")
print(f"[ROUTER] Source: {source}")
print(f"[ROUTER] Kwargs: {kwargs}")
if kind == "hf":
from .loaders.transformers_loader import HFTransformersLoader
loader = HFTransformersLoader()
print(f"[ROUTER] Using HF loader for: {source}")
else: # "gguf"
from .loaders.llamacpp_loader import LlamaCppLoader
loader = LlamaCppLoader()
print(f"[ROUTER] Using GGUF loader for: {source}")
# Load the model
model = loader.load(source, **kwargs)
return model, kind, reason

View File

@@ -0,0 +1,176 @@
from typing import Any, Iterator, List, Optional
import os, json
import torch
from llm_runtime.types import UnifiedModel, GenerateConfig
from llm_runtime.device_utils import device_for_gptq
def _inputs_device_from_gptq_device(dev) -> torch.device:
"""
AutoGPTQ wants: int GPU index | 'cpu' | 'mps' | 'disk'
But tokenizer tensors need a torch.device:
- int -> 'cuda:{idx}'
- 'cpu' -> 'cpu'
- 'mps' -> 'mps'
- 'disk' -> still run the forward on CUDA/CPU; safest default: 'cpu'
"""
if isinstance(dev, int):
return torch.device(f"cuda:{dev}")
if isinstance(dev, str):
if dev in ("cpu", "mps"):
return torch.device(dev)
if dev == "disk":
# inputs live on CPU; model will page as needed
return torch.device("cpu")
# Fallback
return torch.device("cpu")
class _GPTQUnified:
def __init__(self, src: str, **kwargs: Any):
print(f"[GPTQ_DEBUG] _GPTQUnified.__init__() called with src='{src}', kwargs={kwargs}")
try:
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
print("[GPTQ_DEBUG] Successfully imported auto_gptq and transformers")
except ImportError as e:
print(f"[GPTQ_DEBUG] Failed to import auto_gptq or transformers: {e}")
raise ImportError("auto-gptq and transformers are required for GPTQ models. Install with: pip install auto-gptq transformers")
print(f"[GPTQ_DEBUG] Loading GPTQ model from: {src}")
# Normalize device for AutoGPTQ (int or 'cpu'/'mps'/'disk')
raw_device = kwargs.get("device")
print(f"[GPTQ_DEBUG] Raw device from kwargs: {raw_device}")
self._gptq_dev = device_for_gptq(raw_device)
print(f"[GPTQ_DEBUG] Normalized device for GPTQ: {self._gptq_dev}")
self._inputs_device = _inputs_device_from_gptq_device(self._gptq_dev)
print(f"[GPTQ_DEBUG] Using device (gptq): {self._gptq_dev} | inputs will go to: {self._inputs_device}")
trust_remote = kwargs.get("trust_remote_code", True)
token = kwargs.get("token")
# Tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=trust_remote,
token=token
)
# Model
self.model = AutoGPTQForCausalLM.from_quantized(
src,
device=self._gptq_dev, # int or 'cpu'/'mps'/'disk'
trust_remote_code=trust_remote,
use_safetensors=True,
use_triton=kwargs.get("use_triton", False),
token=token
)
# Pad token safety
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Encode stop strings to token IDs (take the first token of each stop string)."""
if not stop:
return None
out: List[int] = []
for s in stop:
if not s:
continue
ids = self.tok.encode(s, add_special_tokens=False)
if ids:
out.append(ids[0])
return out or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
from transformers import StoppingCriteria, StoppingCriteriaList
enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
# Simple but effective; for high perf, implement a token-level matcher.
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
do_sample = cfg.temperature is not None and cfg.temperature > 0.0
gen_kwargs = dict(
max_new_tokens=cfg.max_tokens,
do_sample=do_sample,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
pad_token_id=self.tok.pad_token_id,
)
if do_sample:
gen_kwargs["temperature"] = float(cfg.temperature)
if cfg.stop:
gen_kwargs["stopping_criteria"] = StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)])
out_ids = self.model.generate(**enc, **gen_kwargs)
# Decode only new tokens
new_tokens = out_ids[0][enc.input_ids.shape[1]:]
return self.tok.decode(new_tokens, skip_special_tokens=True)
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
import threading
from transformers import TextIteratorStreamer
enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
do_sample = cfg.temperature is not None and cfg.temperature > 0.0
gen_kwargs = dict(
max_new_tokens=cfg.max_tokens,
do_sample=do_sample,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
if do_sample:
gen_kwargs["temperature"] = float(cfg.temperature)
def _worker():
self.model.generate(**enc, **gen_kwargs)
t = threading.Thread(target=_worker, daemon=True)
t.start()
for chunk in streamer:
yield chunk
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
class AutoGPTQLoader:
name = "gptq"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Local folder?
if os.path.isdir(source):
# quantize_config.json is a strong GPTQ signal
qc = os.path.join(source, "quantize_config.json")
if os.path.exists(qc):
return True
# Fallback: peek at config.json
try:
with open(os.path.join(source, "config.json"), "r", encoding="utf-8") as f:
cfg = json.load(f)
text = json.dumps(cfg).lower()
return ("gptq" in text) or ("quantize_config" in text) or ("quant_method" in text)
except Exception:
pass
# HF repo path (heuristic): let the loader try
elif "/" in source and not os.path.exists(source):
return True
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _GPTQUnified(source, **kwargs)

View File

@@ -0,0 +1,132 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _AWQUnified:
def __init__(self, src: str, **kwargs: Any):
try:
from autoawq import AutoAWQForCausalLM
from transformers import AutoTokenizer
except ImportError:
raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers")
print(f"Loading AWQ model from: {src}")
# Load tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=kwargs.get("trust_remote_code", True)
)
# Load AWQ model
self.model = AutoAWQForCausalLM.from_quantized(
src,
device_map=kwargs.get("device_map", "auto"),
trust_remote_code=kwargs.get("trust_remote_code", True),
safetensors=True,
)
# Set pad token if not present
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Convert stop strings to token IDs"""
if not stop:
return None
ids = []
for s in stop:
if len(s) == 1:
tid = self.tok.convert_tokens_to_ids(s)
if tid is not None:
ids.append(tid)
return ids or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
from transformers import StoppingCriteria, StoppingCriteriaList
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
out_ids = self.model.generate(
**inputs,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
pad_token_id=self.tok.pad_token_id,
)
# Decode only the new tokens
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return generated_text
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
import threading
from transformers import TextIteratorStreamer
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
def _worker():
self.model.generate(
**enc,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
t = threading.Thread(target=_worker)
t.start()
for text in streamer:
yield text
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
class AWQLoader:
name = "awq"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Check for AWQ indicators
if os.path.isdir(source):
# Look for AWQ indicators in config.json
try:
import json
config_path = os.path.join(source, "config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
config_str = str(config).lower()
return ("awq" in config_str or
"quantization_config" in config and
"awq" in str(config.get("quantization_config", {})).lower())
except:
pass
elif "/" in source and not os.path.exists(source): # HF repo
# For HF repos, we'll let it try if it looks like AWQ
return "awq" in source.lower()
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _AWQUnified(source, **kwargs)

View File

@@ -0,0 +1,137 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _ExLlama2Unified:
def __init__(self, src: str, **kwargs: Any):
try:
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
except ImportError:
raise ImportError("exllamav2 is required for EXL2 models. Install with: pip install exllamav2")
print(f"Loading EXL2 model from: {src}")
# Configure model
self.config = ExLlamaV2Config(src)
# Apply any config overrides
if "max_seq_len" in kwargs:
self.config.max_seq_len = kwargs["max_seq_len"]
if "scale_pos_emb" in kwargs:
self.config.scale_pos_emb = kwargs["scale_pos_emb"]
if "scale_alpha_value" in kwargs:
self.config.scale_alpha_value = kwargs["scale_alpha_value"]
# Initialize model
self.model = ExLlamaV2(self.config)
# Load model weights
self.model.load()
# Initialize tokenizer
self.tokenizer = ExLlamaV2Tokenizer(self.config)
# Initialize cache
self.cache = ExLlamaV2Cache(self.model, lazy=kwargs.get("lazy_cache", True))
# Initialize generator for streaming
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
print(f"EXL2 model loaded successfully")
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
# ExLlamaV2 expects stop strings as a list
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Generate text
output = self.generator.generate_simple(
prompt=prompt,
max_new_tokens=cfg.max_tokens,
seed=kwargs.get("seed", -1),
token_healing=kwargs.get("token_healing", True),
temperature=cfg.temperature,
top_p=cfg.top_p,
stop_conditions=stop_conditions,
)
return output
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Begin streaming generation
input_ids = self.tokenizer.encode(prompt)
self.generator.begin_stream(
input_ids=input_ids,
gen_settings=settings,
token_healing=kwargs.get("token_healing", True),
seed=kwargs.get("seed", -1),
)
generated_tokens = 0
while generated_tokens < cfg.max_tokens:
chunk, eos, tokens = self.generator.stream()
if chunk:
yield chunk
generated_tokens += tokens
# Check for stop conditions
if eos:
break
if cfg.stop:
# Check if any stop condition is met in the generated text so far
current_text = self.generator.sequence_str
if any(stop in current_text for stop in cfg.stop):
break
def tokenize(self, text: str) -> List[int]:
return self.tokenizer.encode(text).tolist()
def detokenize(self, ids: List[int]) -> str:
import torch
tensor_ids = torch.tensor([ids], dtype=torch.long)
return self.tokenizer.decode(tensor_ids)[0]
class ExLlama2Loader:
name = "exllama2"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Check for EXL2 model directory structure
if os.path.isdir(source):
# Look for config.json and .safetensors files that indicate EXL2
config_path = os.path.join(source, "config.json")
if os.path.exists(config_path):
# Check if there are .safetensors files with EXL2 naming pattern
for file in os.listdir(source):
if file.endswith(".safetensors") and ("model" in file.lower() or "exl2" in file.lower()):
return True
elif source.lower().endswith(".exl2"):
return True
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _ExLlama2Unified(source, **kwargs)

View File

@@ -0,0 +1,91 @@
from typing import Any, Iterator, List
from llm_runtime.types import UnifiedModel, GenerateConfig
class _LlamaCppUnified:
def __init__(self, model_path: str, **kwargs: Any):
# Import llama_cpp directly instead of from main.py to avoid circular imports
from llama_cpp import Llama
if not model_path.lower().endswith(".gguf"):
raise ValueError(f"Not a valid GGUF model: {model_path}")
self.model_path = model_path
self.kwargs = kwargs
self._llama = None
def _get_model(self):
"""Lazy load the model using existing implementation"""
if self._llama is None:
# Import the _get_llama function from main module to maintain compatibility
try:
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from main import _get_llama
self._llama = _get_llama(
self.model_path,
n_ctx=self.kwargs.get("n_ctx", 8192),
n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
lora_path=self.kwargs.get("lora_path"),
n_threads=self.kwargs.get("n_threads"),
)
except ImportError:
# Fallback to direct llama-cpp-python if main import fails
from llama_cpp import Llama
self._llama = Llama(
model_path=self.model_path,
n_ctx=self.kwargs.get("n_ctx", 8192),
n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
verbose=self.kwargs.get("verbose", False),
n_threads=self.kwargs.get("n_threads")
)
return self._llama
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
llama = self._get_model()
# Convert GenerateConfig to llama-cpp-python format
result = llama(
prompt,
max_tokens=cfg.max_tokens,
temperature=cfg.temperature,
top_p=cfg.top_p,
stop=list(cfg.stop) if cfg.stop else None,
echo=False
)
return result["choices"][0]["text"]
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
llama = self._get_model()
# Use streaming generation
for chunk in llama.create_completion(
prompt,
max_tokens=cfg.max_tokens,
temperature=cfg.temperature,
top_p=cfg.top_p,
stop=list(cfg.stop) if cfg.stop else None,
stream=True,
echo=False
):
text = chunk["choices"][0]["text"]
if text:
yield text
def tokenize(self, text: str) -> List[int]:
llama = self._get_model()
return llama.tokenize(text.encode("utf-8"), add_bos=False)
def detokenize(self, ids: List[int]) -> str:
llama = self._get_model()
return llama.detokenize(ids).decode("utf-8", errors="ignore")
class LlamaCppLoader:
name = "llamacpp"
def can_load(self, source: str, **kwargs: Any) -> bool:
return source.lower().endswith(".gguf")
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _LlamaCppUnified(source, **kwargs)

View File

@@ -0,0 +1,672 @@
from typing import Any, Iterator, List, Optional, Tuple, Dict
import os
import time
from llm_runtime.types import UnifiedModel, GenerateConfig
from llm_runtime.util_chat import apply_chat_template
from llm_runtime.chat_session import ChatSession, get_session_manager
from llm_runtime.device_utils import device_for_hf
class _HFUnified:
def __init__(self, src: str, **kwargs: Any):
print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={kwargs}")
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
print("[HF_DEBUG] Successfully imported torch and transformers")
except ImportError:
print("[HF_DEBUG] Failed to import torch or transformers")
raise ImportError("transformers, torch, and accelerate are required for HF models. Install with: pip install transformers torch accelerate safetensors")
self.torch = torch
self.TextIteratorStreamer = TextIteratorStreamer
# Normalize device for HuggingFace
self.device = device_for_hf(kwargs.get("device"))
print(f"Loading HF model from: {src}")
print(f"Using device: {self.device}")
# Load tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=kwargs.get("trust_remote_code", True),
token=kwargs.get("token")
)
# Prepare quantization config if requested
quantization_config = self._prepare_quantization_config(kwargs)
# Prepare device mapping
device_map, max_memory = self._prepare_device_mapping(kwargs, self.device)
# Load model with advanced options
load_kwargs = {
"torch_dtype": kwargs.get("torch_dtype", "auto"),
"device_map": device_map,
"trust_remote_code": kwargs.get("trust_remote_code", True),
"low_cpu_mem_usage": True,
"token": kwargs.get("token")
}
if quantization_config:
load_kwargs["quantization_config"] = quantization_config
print(f"Using quantization: {kwargs.get('quantization', 'none')}")
# When quantization is enabled, force device_map to "auto" to avoid device format conflicts
load_kwargs["device_map"] = "auto"
print(f"[DEBUG] Forcing device_map='auto' for quantization compatibility")
if max_memory:
load_kwargs["max_memory"] = max_memory
print(f"Memory limits: {max_memory}")
if kwargs.get("offload_folder"):
load_kwargs["offload_folder"] = kwargs.get("offload_folder")
print(f"Offloading to: {kwargs.get('offload_folder')}")
self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs)
# Set pad token if not present
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
# Initialize chat session support with context from kwargs
self.session_manager = get_session_manager()
self.current_session = None
# Store context size for session creation
self.n_ctx = kwargs.get('n_ctx', 4096)
def _prepare_quantization_config(self, kwargs: Any):
"""Prepare quantization configuration"""
quantization = kwargs.get("quantization", "none")
if quantization == "none":
return None
try:
from transformers import BitsAndBytesConfig
except ImportError:
print("Warning: BitsAndBytesConfig not available, skipping quantization")
return None
if quantization == "4bit":
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=self.torch.bfloat16,
)
elif quantization == "8bit":
return BitsAndBytesConfig(
load_in_8bit=True,
)
else:
print(f"Warning: Unknown quantization type '{quantization}', skipping")
return None
def _prepare_device_mapping(self, kwargs: Any, hf_device: str):
"""Prepare device mapping and memory limits"""
device_strategy = kwargs.get("device_strategy", "auto")
gpu_memory_limit = kwargs.get("gpu_memory_limit", None)
device_map = "auto"
max_memory = None
if device_strategy == "force_gpu":
device_map = {"": hf_device}
elif device_strategy == "balanced_split":
# Use memory limits for balanced CPU/GPU split
if gpu_memory_limit:
# For quantization, use integer device index; for non-quantization, use string
if kwargs.get("quantization", "none") != "none":
gpu_device = 0 if hf_device.startswith("cuda:") else 0
else:
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
max_memory = {
gpu_device: f"{gpu_memory_limit}GiB",
"cpu": "48GiB", # Large CPU limit
}
device_map = "auto"
elif device_strategy == "cpu_only":
device_map = {"": "cpu"}
else: # auto
if gpu_memory_limit:
# For quantization, use integer device index; for non-quantization, use string
if kwargs.get("quantization", "none") != "none":
gpu_device = 0 if hf_device.startswith("cuda:") else 0
else:
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
max_memory = {
gpu_device: f"{gpu_memory_limit}GiB",
"cpu": "32GiB",
}
device_map = "auto"
return device_map, max_memory
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Convert stop strings to token IDs"""
if not stop:
return None
ids = []
for s in stop:
# Handle single character stops
if len(s) == 1:
tid = self.tok.convert_tokens_to_ids(s)
if tid is not None:
ids.append(tid)
return ids or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
"""Generate text - uses KV cache if session_id provided, otherwise fallback to standard generation"""
# Set defaults for None values
if cfg.temperature is None:
cfg.temperature = 0.7 # Default temperature
if cfg.top_p is None:
cfg.top_p = 0.95 # Default top_p
if cfg.max_tokens is None:
cfg.max_tokens = 500 # Default max tokens
session_id = kwargs.get('session_id', 'default') # Use 'default' session if none specified
if session_id:
return self.generate_with_cache(prompt, session_id, cfg, **kwargs)
# Fallback to standard generation without KV cache
from transformers import StoppingCriteria, StoppingCriteriaList
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
with self.torch.no_grad():
out_ids = self.model.generate(
**inputs,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
pad_token_id=self.tok.pad_token_id,
)
# Decode only the new tokens
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return generated_text
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
"""Stream text generation - uses KV cache if session_id provided, otherwise fallback to standard streaming"""
session_id = kwargs.pop('session_id', 'default') # Remove from kwargs to avoid duplicate
if session_id:
yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs)
return
# Fallback to standard streaming without KV cache
import threading
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
def _worker():
with self.torch.no_grad():
self.model.generate(
**enc,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
t = threading.Thread(target=_worker)
t.start()
for text in streamer:
yield text
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
def get_session(self, session_id: str = "default") -> ChatSession:
"""Get or create a chat session for KV caching"""
# Create session config with the correct context size
from llm_runtime.chat_session import ChatSessionConfig
session_config = ChatSessionConfig(max_context_length=self.n_ctx)
return self.session_manager.get_session(session_id, config=session_config)
def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Prefill phase: process the full prompt and return logits + past_key_values"""
with self.torch.no_grad():
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Incremental prefill: process only new tokens with existing KV cache"""
with self.torch.no_grad():
outputs = self.model(
input_ids=new_input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Single decode step: generate next token with KV cache"""
with self.torch.no_grad():
outputs = self.model(
input_ids=input_ids, # Should be shape [1, 1] for single token
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int:
"""Sample next token from logits based on generation config with optimized sampling"""
# Get logits for the last token
next_token_logits = logits[0, -1, :]
# Apply temperature scaling
if cfg.temperature is not None and cfg.temperature != 1.0 and cfg.temperature > 0:
next_token_logits = next_token_logits / cfg.temperature
# Apply top-p (nucleus) sampling if specified
if cfg.temperature is not None and cfg.temperature > 0.0 and cfg.top_p is not None and cfg.top_p < 1.0:
sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True)
cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > cfg.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
if cfg.temperature is not None and cfg.temperature > 0.0:
# Sample from distribution
probs = self.torch.softmax(next_token_logits, dim=-1)
next_token = self.torch.multinomial(probs, num_samples=1)
else:
# Greedy sampling (deterministic)
next_token = self.torch.argmax(next_token_logits, dim=-1, keepdim=True)
return next_token.item()
def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool:
"""Check if generation should stop"""
# Check for EOS token
if token_id == self.tok.eos_token_id:
return True
# Check for custom stop strings
if cfg.stop:
for stop_str in cfg.stop:
if stop_str in generated_text:
return True
return False
def generate_with_cache(self, prompt: str, session_id: str = "default",
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
"""Generate text with persistent KV cache using manual generation loop"""
# Set defaults for None values
if cfg.temperature is None:
cfg.temperature = 0.7 # Default temperature
if cfg.top_p is None:
cfg.top_p = 0.95 # Default top_p
if cfg.max_tokens is None:
cfg.max_tokens = 500 # Default max tokens
session = self.get_session(session_id)
# Tokenize input
inputs = self.tok(prompt, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(self.model.device)
attention_mask = inputs.attention_mask.to(self.model.device)
generated_tokens = []
max_new_tokens = cfg.max_tokens
# Check if we need to invalidate cache with proper token validation
if session.should_invalidate(prompt, self.tok):
session.invalidate_cache()
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
# Determine if we need prefill phase
if session.past_key_values is None:
print(f"[KV_CACHE] Running prefill phase for {len(input_ids[0])} tokens")
# Prefill phase: process full prompt
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
# Update session cache
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Update input_ids and attention_mask for decode phase
current_length = input_ids.shape[1]
else:
print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}")
# Use cached state
past_key_values = session.past_key_values
current_length = session.context_length
# For cached state, we still need to process the new part of the prompt if it exists
cached_length = len(session.cached_input_ids)
if input_ids.shape[1] > cached_length:
print(f"[KV_CACHE] Processing {input_ids.shape[1] - cached_length} new tokens (incremental prefill)")
new_tokens = input_ids[:, cached_length:]
# Process only the new tokens with existing KV cache
# Create attention mask that covers both cached and new tokens
total_length = cached_length + new_tokens.shape[1]
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
current_length = input_ids.shape[1]
else:
# No new tokens to process - get initial logits for generation
# Use a forward pass with the last token to get proper logits distribution
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
total_length = current_length + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
# Sample first new token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Decode phase: generate tokens one by one with optimized attention management
for step in range(max_new_tokens - 1): # -1 because we already generated first token
# Check stop conditions
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
if self._should_stop(next_token_id, generated_text, cfg):
break
# Prepare inputs for next step - only pass the new token
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
total_length = current_length + len(generated_tokens) + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
# Generate next token using cached KV states
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Update session with final state - include all processed tokens
final_input_ids = input_ids[0].tolist() + generated_tokens
session.update_cache(past_key_values, final_input_ids, prompt + self.tok.decode(generated_tokens, skip_special_tokens=True))
# Decode generated tokens
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
print(f"[KV_CACHE] Generated {len(generated_tokens)} tokens with KV cache")
return generated_text
def stream_with_cache(self, prompt: str, session_id: str = "default",
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
"""Stream generation with persistent KV cache"""
session = self.get_session(session_id)
# Tokenize input
inputs = self.tok(prompt, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(self.model.device)
attention_mask = inputs.attention_mask.to(self.model.device)
generated_tokens = []
max_new_tokens = cfg.max_tokens
# Check if we need to invalidate cache with proper token validation
if session.should_invalidate(prompt, self.tok):
session.invalidate_cache()
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
# Determine if we need prefill phase
if session.past_key_values is None:
print(f"[KV_CACHE] Streaming prefill phase for {len(input_ids[0])} tokens")
# Prefill phase: process full prompt
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
# Update session cache
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
current_length = input_ids.shape[1]
# Yield first token
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if first_text:
yield first_text
else:
print(f"[KV_CACHE] Streaming with cached KV state, context length: {session.context_length}")
# Use cached state - optimized logic similar to generate_with_cache
past_key_values = session.past_key_values
current_length = session.context_length
# Handle new tokens in prompt with incremental prefill
cached_length = len(session.cached_input_ids)
if input_ids.shape[1] > cached_length:
print(f"[KV_CACHE] Streaming incremental prefill for {input_ids.shape[1] - cached_length} new tokens")
new_tokens = input_ids[:, cached_length:]
# Process only new tokens with existing KV cache
total_length = cached_length + new_tokens.shape[1]
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
current_length = input_ids.shape[1]
else:
# No new tokens - use last cached token for initial generation
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
total_length = current_length + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Yield first token
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if first_text:
yield first_text
# Decode phase: stream tokens one by one with optimized KV cache usage
for step in range(max_new_tokens - 1):
# Check stop conditions
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
if self._should_stop(next_token_id, generated_text, cfg):
break
# Prepare inputs for next step - efficient single token processing
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
total_length = current_length + len(generated_tokens) + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
# Generate next token using cached KV states
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Yield the new token immediately
token_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if token_text:
yield token_text
# Update session with final state - include all processed tokens
final_input_ids = input_ids[0].tolist() + generated_tokens
full_generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
session.update_cache(past_key_values, final_input_ids, prompt + full_generated_text)
print(f"[KV_CACHE] Streamed {len(generated_tokens)} tokens with persistent KV cache")
def clear_session_cache(self, session_id: str = "default") -> None:
"""Clear KV cache for a specific session"""
session = self.get_session(session_id)
session.invalidate_cache()
print(f"[DEBUG] Cleared cache for session {session_id}")
def get_session_info(self, session_id: str = "default") -> Dict[str, Any]:
"""Get information about a chat session"""
session = self.get_session(session_id)
return session.get_cache_info()
def add_conversation_turn(self, user_message: str, assistant_message: str,
session_id: str = "default") -> None:
"""Add a complete conversation turn to the session history"""
session = self.get_session(session_id)
session.add_message("user", user_message)
session.add_message("assistant", assistant_message)
def get_model_info(self) -> Dict[str, Any]:
"""Get comprehensive information about the loaded model"""
try:
# Basic model info
info = {
"model_name": getattr(self.model.config, 'name_or_path', 'unknown'),
"model_type": self.model.config.model_type,
"vocab_size": self.model.config.vocab_size,
"device": str(self.device),
"dtype": str(self.model.dtype),
"supports_kv_cache": True,
"max_position_embeddings": getattr(self.model.config, 'max_position_embeddings', 'unknown'),
"torch_compile_enabled": hasattr(self.model, '_orig_mod')
}
# Memory info
if self.torch.cuda.is_available() and str(self.device) != 'cpu':
info.update({
"gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / 1024**3:.2f} GB",
"gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / 1024**3:.2f} GB",
"gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
})
# Model parameters
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
info.update({
"total_parameters": f"{total_params:,}",
"trainable_parameters": f"{trainable_params:,}",
"model_size_mb": f"{total_params * 4 / 1024**2:.2f} MB" # Assuming float32
})
return info
except Exception as e:
return {"error": f"Could not get model info: {e}", "supports_kv_cache": True}
def get_kv_cache_stats(self) -> Dict[str, Any]:
"""Get KV cache statistics across all sessions"""
try:
sessions = self.session_manager.get_all_sessions()
stats = {
"total_sessions": len(sessions),
"active_sessions": 0,
"total_cached_tokens": 0,
"memory_usage_estimate": "0 MB"
}
for session_id in sessions:
session = self.session_manager.get_session(session_id)
if session.past_key_values is not None:
stats["active_sessions"] += 1
stats["total_cached_tokens"] += session.context_length
# Rough estimate of KV cache memory usage
# Each token in KV cache roughly uses: hidden_size * num_layers * 2 (key + value) * 4 bytes (float32)
if stats["total_cached_tokens"] > 0:
try:
hidden_size = self.model.config.hidden_size
num_layers = self.model.config.num_hidden_layers
memory_bytes = stats["total_cached_tokens"] * hidden_size * num_layers * 2 * 4
stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB"
except:
pass
return stats
except Exception as e:
return {"error": f"Could not get KV cache stats: {e}"}
def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, float]:
"""Warm up the model and measure performance metrics"""
try:
print("[KV_CACHE] Warming up model...")
start_time = time.time()
# Simple generation to warm up CUDA kernels
cfg = GenerateConfig(max_tokens=5, temperature=0.0)
_ = self.generate(test_prompt, cfg=cfg, session_id="warmup")
warmup_time = time.time() - start_time
# Clean up warmup session
self.clear_session_cache("warmup")
return {
"warmup_time": warmup_time,
"status": "success"
}
except Exception as e:
return {
"warmup_time": 0.0,
"status": "failed",
"error": str(e)
}
class HFTransformersLoader:
name = "hf"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Accept HF repo-id or local dir with config.json (covers .safetensors)
if "/" in source and not os.path.exists(source): # repo-id like "microsoft/DialoGPT-medium"
return True
# Check if it's a directory with config.json
if os.path.isdir(source) and os.path.exists(os.path.join(source, "config.json")):
return True
# If it's a single .safetensors file, look for config.json in the same directory
if source.lower().endswith('.safetensors'):
parent_dir = os.path.dirname(source)
return os.path.exists(os.path.join(parent_dir, "config.json"))
# Also support .bin files (PyTorch checkpoints)
if source.lower().endswith('.bin'):
parent_dir = os.path.dirname(source)
return os.path.exists(os.path.join(parent_dir, "config.json"))
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _HFUnified(source, **kwargs)

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
import os, json
from pathlib import Path
from typing import Literal, Optional, Dict, Tuple, Callable
# Only support Hugging Face (Safetensors) and GGUF
LoaderKind = Literal["hf", "gguf"]
def detect_loader_type(source: str) -> Tuple[LoaderKind, str]:
"""
Decide which loader to use based on the path/repo.
Returns: (kind, reason)
kind ∈ {"hf","gguf"}
"""
print(f"[ROUTER_DEBUG] detect_loader_type() called with source: '{source}'")
p = Path(source)
low = source.lower()
# 0) Force HF for official Meta repos
if not p.exists() and low.startswith("meta-llama/"):
result = ("hf", "Official Meta repo (full-precision).")
print(f"[ROUTER_DEBUG] Meta repo detected -> {result}")
return result
# 1) Local FILE
if p.exists() and p.is_file():
if p.suffix.lower() == ".gguf":
return "gguf", "Local .gguf file."
return "hf", "Local non-.gguf file (default HF)."
# 2) Local DIR
if p.exists() and p.is_dir():
# GGUF hint: any *.gguf inside
if any(p.glob("*.gguf")):
return "gguf", "Directory contains .gguf file(s)."
# Inspect config.json only for GGUF hints; all else defaults to HF
cfg = p / "config.json"
if cfg.exists():
try:
data = json.loads(cfg.read_text(encoding="utf-8"))
text = json.dumps(data).lower()
if ("gguf" in text) or ("llama.cpp" in text) or ("ggml" in text):
return "gguf", "config.json mentions GGUF/llama.cpp."
except Exception:
pass
return "hf", "Default HF for non-GGUF directories."
# 3) Remote repo style (org/name)
if ("/" in source or source.count("\\") == 1) and not p.exists():
if any(tag in low for tag in ["gguf", "ggml", "llama.cpp"]):
return "gguf", "Repo name suggests GGUF."
return "hf", "Remote repo (default HF)."
# 4) Fallback
return "hf", "Fallback to HF."

10
llm_runtime/registry.py Normal file
View File

@@ -0,0 +1,10 @@
from typing import Any
from .types import UnifiedModel
from .loader_factory import load_model_for_gui
def load_model(source: str, **kwargs: Any) -> UnifiedModel:
"""Load model using the router-based factory system"""
print(f"[REGISTRY_DEBUG] load_model() called with source='{source}', kwargs={kwargs}")
model, kind, reason = load_model_for_gui(source, **kwargs)
print(f"[REGISTRY_DEBUG] load_model_for_gui() returned: kind='{kind}', reason='{reason}'")
return model

23
llm_runtime/types.py Normal file
View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional, Protocol, Any
@dataclass
class GenerateConfig:
# None means "use model defaults"
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
stop: Optional[Iterable[str]] = None
# Optional context window hint; models may ignore if they manage it internally
context_window: Optional[int] = None
# Optional advanced knobs; models may ignore if unsupported
min_p: Optional[float] = None
typical_p: Optional[float] = None
repetition_penalty: Optional[float] = None
class UnifiedModel(Protocol):
# Default config uses None for all tunables so the model can choose
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: ...
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: ...
def tokenize(self, text: str) -> List[int]: ...
def detokenize(self, ids: List[int]) -> str: ...

23
llm_runtime/util_chat.py Normal file
View File

@@ -0,0 +1,23 @@
from typing import Any, Dict, List, Optional
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
"""Apply chat template to messages using tokenizer or fallback to ChatML format"""
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
# Fallback: simple ChatML-ish format
parts = []
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
if role == "system":
parts.append(f"<|system|>\n{content}\n")
elif role == "assistant":
parts.append(f"<|assistant|>\n{content}\n")
else:
parts.append(f"<|user|>\n{content}\n")
if add_generation_prompt:
parts.append("<|assistant|>\n")
return "".join(parts)