first commit
This commit is contained in:
37
llm_runtime/__init__.py
Normal file
37
llm_runtime/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from .registry import load_model as _load_model
|
||||
from .types import UnifiedModel, GenerateConfig
|
||||
|
||||
# Announce the loaded model to the global spy so other parts of the app can access it
def load_model(*args, **kwargs):
    """
    Load a model through ``registry.load_model`` and publish it via ``__spy``.

    Every positional and keyword argument is forwarded untouched, so this
    wrapper stays compatible with all loaders. Publishing is best-effort:
    whatever goes wrong while announcing is swallowed and the loaded model is
    returned regardless.
    """
    model = _load_model(*args, **kwargs)

    def _can_repr(value):
        # A kwarg is forwarded to the spy only if repr() on it does not raise.
        try:
            repr(value)
        except Exception:
            return False
        return True

    try:
        import __spy as spy  # local import to avoid hard dependency during tooling

        # Best-effort model name: first truthy candidate, evaluated lazily
        # in priority order (source kwarg, model kwarg, first positional,
        # the model's own ``name`` attribute, then a fixed fallback).
        label = kwargs.get("source") or kwargs.get("model")
        if not label and args:
            label = args[0]
        if not label:
            label = getattr(model, "name", None) or "unknown"

        # Shallow capture of load parameters (omit non-representable values)
        announced = {key: val for key, val in kwargs.items() if _can_repr(val)}

        spy.set_model(str(label), model, **announced)
    except Exception:
        # Never let announcing break model loading
        pass
    return model
|
||||
|
||||
__all__ = ["load_model", "UnifiedModel", "GenerateConfig"]
|
||||
268
llm_runtime/chat_session.py
Normal file
268
llm_runtime/chat_session.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
Chat Session Management for KV Caching
|
||||
|
||||
This module provides chat session management with persistent KV cache
|
||||
for efficient multi-turn conversations across different model formats.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
class ChatSessionConfig:
    """Settings bundle attached to a single chat session."""

    # Context-length limit recorded for the session (units presumably tokens; confirm with callers)
    max_context_length: int = 4096
    # Pre-fill system prompt
    cache_warmup: bool = True
    # Streaming flag, enabled by default
    streaming: bool = True
|
||||
|
||||
|
||||
class ChatSession:
    """
    Manages persistent KV cache state for multi-turn conversations.

    Supports different model formats:
    - GGUF models: Rely on llama-cpp-python's built-in caching
    - Transformers models: Manual KV cache management with past_key_values

    Every method that mutates the session, and every method that reads
    several cache fields together, takes ``self.lock`` (a re-entrant lock).
    """

    # End-of-turn markers recognised when checking turn boundaries
    _EOT_MARKERS = ("<|eot_id|>", "<|end_of_turn|>", "</s>", "<|endoftext|>")
    # Role labels that may legitimately follow a finished turn
    _CONTINUATION_PREFIXES = ("User:", "Human:")
    # Substrings treated as evidence of a corrupted conversation
    _CORRUPTION_PATTERNS = (
        "Assistant: Assistant:",  # Duplicate assistant labels
        "User: User:",  # Duplicate user labels
        "Assistant: User",  # Role confusion
        "User: Assistant:",  # Role confusion
        "of of of",  # Repetitive token generation (sign of corruption)
        "151 of 151",  # Specific corruption pattern we observed
    )

    def __init__(self, session_id: str, config: "Optional[ChatSessionConfig]" = None):
        """Create an empty session; a default ``ChatSessionConfig`` is used when none is given."""
        self.session_id = session_id
        self.config = config or ChatSessionConfig()
        self.lock = threading.RLock()

        # KV cache state (transformers models)
        self.past_key_values: Optional[Tuple] = None
        self.cached_input_ids: Optional[List[int]] = None
        self.context_length: int = 0

        # Conversation history
        self.messages: List[Dict[str, str]] = []
        self.system_prompt: Optional[str] = None

        # State tracking
        self.is_prefilled: bool = False
        self.last_prompt_hash: Optional[str] = None
        self.last_prompt: Optional[str] = None

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the conversation history.

        A ``system`` message that arrives while the history is still empty is
        also remembered as the session's system prompt.
        """
        with self.lock:
            if role == "system" and not self.messages:
                self.system_prompt = content
            self.messages.append({"role": role, "content": content})

    def clear_messages(self) -> None:
        """Clear conversation history but preserve system prompt; the KV cache is invalidated too."""
        with self.lock:
            if self.system_prompt:
                self.messages = [{"role": "system", "content": self.system_prompt}]
            else:
                self.messages = []
            self.invalidate_cache()

    def invalidate_cache(self) -> None:
        """Invalidate KV cache - forces re-prefill on next generation"""
        with self.lock:
            self.past_key_values = None
            self.cached_input_ids = None
            self.context_length = 0
            self.is_prefilled = False
            self.last_prompt_hash = None
            self.last_prompt = None

    def should_invalidate(self, new_prompt: str, tokenizer=None) -> bool:
        """Decide whether the cached KV state can be reused for ``new_prompt``.

        Returns ``True`` (invalidate) when there is no cache, no tokenizer to
        validate with, no previously cached prompt, when the cached prompt is
        not an exact token-level prefix of ``new_prompt``, or when a turn
        boundary / corruption check fails. Returns ``False`` only when the
        cache is provably a clean prefix.
        """
        # Hold the lock so the cache fields are read as one consistent snapshot
        # while update_cache()/invalidate_cache() may run on another thread.
        with self.lock:
            # Always invalidate if no cache exists
            if self.past_key_values is None:
                return True

            # Always invalidate if we don't have a tokenizer for proper validation
            if tokenizer is None:
                return True

            # Always invalidate if no previous prompt exists
            previous = self.last_prompt
            if previous is None:
                return True

            # Check for exact token-level prefix match
            if not self._has_valid_token_prefix(new_prompt, previous, tokenizer):
                print("[KV_CACHE] Invalidating cache: no valid token prefix match")
                return True

            # Additional safety checks for turn boundaries
            if self._violates_turn_boundaries(new_prompt, previous):
                print("[KV_CACHE] Invalidating cache: turn boundary violation detected")
                return True

            return False

    def _has_valid_token_prefix(self, new_prompt: str, old_prompt: str, tokenizer) -> bool:
        """Check if new prompt has exact token-level prefix match with cached prompt.

        ``tokenizer`` must provide ``encode(text, add_special_tokens=False)``.
        Any exception raised while tokenizing is reported and treated as
        "no match".
        """
        try:
            # Tokenize both prompts
            old_tokens = tokenizer.encode(old_prompt, add_special_tokens=False)
            new_tokens = tokenizer.encode(new_prompt, add_special_tokens=False)

            # New prompt must be longer than or equal to old
            if len(new_tokens) < len(old_tokens):
                return False

            # Check exact token-by-token match for the prefix
            for i, (old_token, new_token) in enumerate(zip(old_tokens, new_tokens)):
                if old_token != new_token:
                    print(f"[KV_CACHE] Token mismatch at position {i}: old={old_token}, new={new_token}")
                    return False

            return True
        except Exception as e:
            print(f"[KV_CACHE] Error in token prefix validation: {e}")
            return False

    def _violates_turn_boundaries(self, new_prompt: str, old_prompt: str) -> bool:
        """Check if the prompt violates turn boundaries (conversation integrity)."""
        # If the old prompt ended with an EOT marker, the text appended after it
        # must open a new user turn. Leading whitespace (including newlines) is
        # stripped before the check, so "\nUser:" is covered by "User:".
        if old_prompt.rstrip().endswith(self._EOT_MARKERS):
            continuation = new_prompt[len(old_prompt):].lstrip()
            if not continuation.startswith(self._CONTINUATION_PREFIXES):
                return True

        # Check for conversation format corruption (duplicate roles, malformed structure)
        return self._has_conversation_corruption(new_prompt)

    def _has_conversation_corruption(self, prompt: str) -> bool:
        """Detect conversation format corruption that indicates cache should be invalidated"""
        for pattern in self._CORRUPTION_PATTERNS:
            if pattern in prompt:
                print(f"[KV_CACHE] Detected conversation corruption: '{pattern}'")
                return True
        return False

    def update_cache(self, past_key_values: Tuple, input_ids: List[int], prompt: str) -> None:
        """Store new KV cache state together with the prompt it was computed for.

        An incomplete-looking turn only triggers a warning; the cache is
        stored either way.
        """
        import hashlib

        with self.lock:
            # Validate that we're ending on a complete turn boundary
            if not self._is_complete_turn(prompt):
                print("[KV_CACHE] Warning: Caching incomplete turn - this may cause issues")

            self.past_key_values = past_key_values
            self.cached_input_ids = input_ids
            self.context_length = len(input_ids)
            self.is_prefilled = True
            self.last_prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
            self.last_prompt = prompt  # Store the actual prompt for comparison

    def _is_complete_turn(self, prompt: str) -> bool:
        """Check if the prompt ends on a complete turn boundary"""
        prompt_trimmed = prompt.rstrip()

        # Explicit end-of-turn marker
        if prompt_trimmed.endswith(self._EOT_MARKERS):
            return True

        # For non-marked conversations, check if it looks like a complete response
        # (ends with punctuation and doesn't look cut off)
        if prompt_trimmed.endswith(('.', '!', '?', ':', ';')):
            return True

        # If it ends with "Assistant:" it's ready for generation
        return prompt_trimmed.endswith("Assistant:")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about current cache state"""
        with self.lock:
            return {
                "session_id": self.session_id,
                "has_cache": self.past_key_values is not None,
                "context_length": self.context_length,
                "is_prefilled": self.is_prefilled,
                "message_count": len(self.messages),
                "max_context": self.config.max_context_length
            }
|
||||
|
||||
|
||||
class ChatSessionManager:
    """Registry of chat sessions keyed by session id, guarded by a re-entrant lock."""

    def __init__(self):
        self.sessions: Dict[str, "ChatSession"] = {}
        self.lock = threading.RLock()
        # Id used whenever a caller passes no session id
        self.default_session_id = "default"

    def get_session(self, session_id: str = None, config: "ChatSessionConfig" = None) -> "ChatSession":
        """Return the session stored under ``session_id``, creating it on first use.

        ``config`` is only consulted when a new session has to be created.
        """
        key = self.default_session_id if session_id is None else session_id

        with self.lock:
            try:
                return self.sessions[key]
            except KeyError:
                created = ChatSession(key, config)
                self.sessions[key] = created
                return created

    def clear_session(self, session_id: str = None) -> None:
        """Reset cache and history of one session; unknown ids are ignored."""
        key = self.default_session_id if session_id is None else session_id

        with self.lock:
            target = self.sessions.get(key) if key in self.sessions else None
            if key in self.sessions:
                target.invalidate_cache()
                target.clear_messages()

    def remove_session(self, session_id: str) -> None:
        """Drop a session from the registry; unknown ids are ignored."""
        with self.lock:
            self.sessions.pop(session_id, None)

    def get_all_sessions(self) -> List[str]:
        """Return a snapshot list of every registered session id."""
        with self.lock:
            return [sid for sid in self.sessions]
|
||||
|
||||
|
||||
# Global session manager instance (created once, at import time)
_session_manager = ChatSessionManager()


def get_session_manager() -> ChatSessionManager:
    """Return the module-wide ChatSessionManager instance."""
    manager = _session_manager
    return manager
|
||||
99
llm_runtime/device_utils.py
Normal file
99
llm_runtime/device_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from typing import Union, Literal
|
||||
|
||||
# Backends that normalize_device can produce a device value for
Backend = Literal["hf", "gptq"]
# Device specs accepted from callers (None means "pick automatically")
DevIn = Union[None, str, int]
# Normalized device value handed back to a backend
DevOut = Union[str, int]
|
||||
|
||||
def _has_mps() -> bool:
    """Report whether this torch build exposes an MPS backend that is available."""
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        # torch build without the mps backend module
        return False
    return torch.backends.mps.is_available()
|
||||
|
||||
def _first_cuda_index() -> int | None:
|
||||
return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None
|
||||
|
||||
def normalize_device(dev: "DevIn" = None, *, backend: "Backend" = "hf") -> "DevOut":
    """
    Normalize a user/device string/int into what each backend expects.

    Inputs accepted (strings are matched case-insensitively, surrounding
    whitespace ignored):
        None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int)

    Returns:
        backend == "hf":   "cpu" | "mps" | "cuda:N"
        backend == "gptq": "cpu" | "mps" | "disk" | N (int)

    Raises:
        ValueError: for a malformed CUDA string, a negative GPU index, a bool,
            or any other unsupported spec.
    """

    def _cuda(index: int):
        # gptq addresses GPUs by bare int index; every other backend gets "cuda:N"
        return index if backend == "gptq" else f"cuda:{index}"

    # bool is a subclass of int; True/False would otherwise become "cuda:True"
    if isinstance(dev, bool):
        raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}")

    # Compare on a lower-cased copy and return canonical values, so "CPU"
    # or "Cuda:0" never leak through in their original spelling.
    spec = dev.strip().lower() if isinstance(dev, str) else dev

    # 1) Auto/default
    if spec is None or spec == "auto":
        cuda0 = _first_cuda_index()
        if cuda0 is not None:
            return _cuda(cuda0)
        if _has_mps():
            return "mps"
        # AutoGPTQ can also run with 'disk' offload if caller wants; default to CPU here
        return "cpu"

    if isinstance(spec, str):
        # 2) Explicit CPU/MPS/DISK
        if spec in ("cpu", "mps"):
            return spec
        if spec == "disk":
            # HF does not know "disk"; treat as CPU for HF branch
            return "disk" if backend == "gptq" else "cpu"

        # 3) Explicit CUDA string
        if spec == "cuda":
            idx = _first_cuda_index()
            # No CUDA available; degrade to CPU for every backend
            return "cpu" if idx is None else _cuda(idx)
        if spec.startswith("cuda"):
            bad = ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.")
            if not spec.startswith("cuda:"):
                raise bad
            try:
                idx = int(spec[len("cuda:"):])
            except ValueError:
                raise bad from None
            if idx < 0:
                raise ValueError(f"GPU index must be >= 0, got {idx}")
            return _cuda(idx)

    # 4) Integer GPU index
    elif isinstance(spec, int):
        if spec < 0:
            raise ValueError(f"GPU index must be >= 0, got {spec}")
        return _cuda(spec)

    raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}")
|
||||
|
||||
# --- Convenience wrappers -------------------------------------------------------
|
||||
|
||||
def device_for_hf(dev: DevIn = None) -> str:
    """Resolve ``dev`` with the "hf" backend rules (e.g., 'cuda:0', 'cpu', 'mps')."""
    resolved = normalize_device(dev, backend="hf")
    # Invariant: the hf branch of normalize_device only yields strings
    assert isinstance(resolved, str)
    return resolved
|
||||
|
||||
def device_for_gptq(dev: DevIn = None) -> Union[int, str]:
    """Resolve ``dev`` with the "gptq" backend rules: an int GPU index or 'cpu'/'mps'/'disk'."""
    resolved = normalize_device(dev, backend="gptq")
    # Invariant: the gptq branch yields either an int index or a string
    assert isinstance(resolved, (int, str))
    return resolved
|
||||
|
||||
def debug_device_placement(model, name="model"):
    """Print where a model's parameters live and return the first parameter's device.

    ``model`` must offer ``named_parameters()`` and ``parameters()``. Any
    failure (including a model with no parameters) is printed and ``None``
    is returned instead of raising.
    """
    try:
        seen = {str(tensor.device) for _, tensor in model.named_parameters()}
        print(f"[DEBUG] {name} parameters on devices: {seen}")

        # Check first parameter device
        head = next(model.parameters())
        print(f"[DEBUG] {name} primary device: {head.device}")
    except Exception as err:
        print(f"[DEBUG] Could not check {name} device placement: {err}")
        return None
    return head.device
|
||||
|
||||
# --- Minimal self-test (run this file directly) ---------------------------------
if __name__ == "__main__":
    # Print the hf and gptq normalization of every supported spec style
    for spec in (None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1):
        hf_dev = device_for_hf(spec)
        gptq_dev = device_for_gptq(spec)
        print(f"in={spec!r:7} -> hf={hf_dev!r:7} gptq={gptq_dev!r}")
|
||||
25
llm_runtime/loader_factory.py
Normal file
25
llm_runtime/loader_factory.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# loader_factory.py
|
||||
from typing import Any, Tuple
|
||||
from .model_router import detect_loader_type, LoaderKind
|
||||
|
||||
def select_loader(source: str) -> Tuple[LoaderKind, str]:
    """Return whatever ``detect_loader_type`` decides for ``source`` (loader kind plus reason)."""
    decision = detect_loader_type(source)
    return decision
|
||||
|
||||
def load_model_for_gui(source: str, **kwargs: Any):
    """Pick a loader for ``source``, load the model, and report the routing decision.

    Args:
        source: model location passed to ``detect_loader_type`` and the loader.
        **kwargs: forwarded unchanged to the chosen loader's ``load``.

    Returns:
        A ``(model, kind, reason)`` tuple, where ``kind`` and ``reason`` are
        the values returned by ``detect_loader_type``.
    """
    kind, reason = detect_loader_type(source)
    print(f"[ROUTER] Using {kind.upper()} loader: {reason}")
    print(f"[ROUTER] Source: {source}")

    # Credentials (e.g. an access token) must never be echoed to stdout, where
    # they end up in GUI consoles and captured logs. Only the printed copy is
    # masked; the loader still receives the real values.
    secret_keys = {"token", "use_auth_token", "hf_token", "api_key"}
    printable = {
        key: ("***" if key.lower() in secret_keys and value else value)
        for key, value in kwargs.items()
    }
    print(f"[ROUTER] Kwargs: {printable}")

    if kind == "hf":
        from .loaders.transformers_loader import HFTransformersLoader
        loader = HFTransformersLoader()
        print(f"[ROUTER] Using HF loader for: {source}")
    else:  # "gguf"
        from .loaders.llamacpp_loader import LlamaCppLoader
        loader = LlamaCppLoader()
        print(f"[ROUTER] Using GGUF loader for: {source}")

    # Load the model
    model = loader.load(source, **kwargs)
    return model, kind, reason
|
||||
176
llm_runtime/loaders/autogptq_loader.py
Normal file
176
llm_runtime/loaders/autogptq_loader.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from typing import Any, Iterator, List, Optional
|
||||
import os, json
|
||||
import torch
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
from llm_runtime.device_utils import device_for_gptq
|
||||
|
||||
def _inputs_device_from_gptq_device(dev) -> torch.device:
    """
    Translate an AutoGPTQ device value into the torch.device for input tensors.

    AutoGPTQ wants: int GPU index | 'cpu' | 'mps' | 'disk'
    Tokenizer tensors need a torch.device instead:
    - int            -> 'cuda:{idx}'
    - 'cpu' / 'mps'  -> same name
    - 'disk'         -> 'cpu' (inputs live on CPU; model will page as needed)
    - anything else  -> 'cpu' (fallback)
    """
    if isinstance(dev, int):
        target = f"cuda:{dev}"
    elif isinstance(dev, str) and dev in ("cpu", "mps"):
        target = dev
    else:
        target = "cpu"
    return torch.device(target)
|
||||
|
||||
class _GPTQUnified:
    """Model wrapper around an AutoGPTQ checkpoint and its HF tokenizer.

    Exposes ``generate``, ``stream``, ``tokenize`` and ``detokenize``.
    """

    # kwargs whose values are masked before being printed
    _SECRET_KWARGS = frozenset({"token", "use_auth_token"})

    def __init__(self, src: str, **kwargs: Any):
        """Load tokenizer and quantized model from ``src``.

        Recognised kwargs: ``device``, ``trust_remote_code``, ``token``,
        ``use_triton``.

        Raises:
            ImportError: if auto-gptq or transformers is not installed.
        """
        # Never echo credentials: mask secret kwargs in the debug line only.
        shown = {k: ("***" if k in self._SECRET_KWARGS and v else v) for k, v in kwargs.items()}
        print(f"[GPTQ_DEBUG] _GPTQUnified.__init__() called with src='{src}', kwargs={shown}")
        try:
            from auto_gptq import AutoGPTQForCausalLM
            from transformers import AutoTokenizer
            print("[GPTQ_DEBUG] Successfully imported auto_gptq and transformers")
        except ImportError as e:
            print(f"[GPTQ_DEBUG] Failed to import auto_gptq or transformers: {e}")
            raise ImportError("auto-gptq and transformers are required for GPTQ models. Install with: pip install auto-gptq transformers") from e

        print(f"[GPTQ_DEBUG] Loading GPTQ model from: {src}")

        # Normalize device for AutoGPTQ (int or 'cpu'/'mps'/'disk')
        raw_device = kwargs.get("device")
        print(f"[GPTQ_DEBUG] Raw device from kwargs: {raw_device}")
        self._gptq_dev = device_for_gptq(raw_device)
        print(f"[GPTQ_DEBUG] Normalized device for GPTQ: {self._gptq_dev}")
        self._inputs_device = _inputs_device_from_gptq_device(self._gptq_dev)
        print(f"[GPTQ_DEBUG] Using device (gptq): {self._gptq_dev} | inputs will go to: {self._inputs_device}")

        # NOTE(review): trust_remote_code defaults to True, which allows code
        # shipped in the model repository to execute. Kept for compatibility;
        # consider defaulting to False for untrusted sources.
        trust_remote = kwargs.get("trust_remote_code", True)
        token = kwargs.get("token")

        # Tokenizer
        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote,
            token=token
        )

        # Model
        self.model = AutoGPTQForCausalLM.from_quantized(
            src,
            device=self._gptq_dev,  # int or 'cpu'/'mps'/'disk'
            trust_remote_code=trust_remote,
            use_safetensors=True,
            use_triton=kwargs.get("use_triton", False),
            token=token
        )

        # Pad token safety
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Build the EOS id list for a set of stop strings.

        Only stop strings that encode to exactly ONE token are usable as EOS
        ids. Using the first token of a longer string would end generation as
        soon as that unrelated token appears; multi-token stops are handled
        by the text-based stopping criterion instead. The tokenizer's own EOS
        id is always kept so the model's natural end-of-sequence still works.

        Returns ``None`` when no stop strings are given, which leaves
        ``eos_token_id`` unset in the generate call.
        """
        if not stop:
            return None
        out: List[int] = []
        own_eos = getattr(self.tok, "eos_token_id", None)
        if own_eos is not None:
            out.append(own_eos)
        for s in stop:
            if not s:
                continue
            ids = self.tok.encode(s, add_special_tokens=False)
            if len(ids) == 1 and ids[0] not in out:
                out.append(ids[0])
        return out or None

    def _stop_criteria(self, stops: List[str], prompt_len: int):
        """Return a StoppingCriteriaList that fires when a stop string shows up in the NEW text."""
        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok

        class _StopOnStrings(StoppingCriteria):
            def __call__(self, input_ids, scores, **_):
                # Decode only tokens after the prompt: a stop string that is
                # already part of the prompt must not end generation at once.
                # Simple but effective; for high perf, implement a token-level matcher.
                text = tok.decode(input_ids[0][prompt_len:], skip_special_tokens=True)
                return any(s in text for s in stops)

        return StoppingCriteriaList([_StopOnStrings()])

    def _generation_kwargs(self, cfg, prompt_len: int, streamer=None) -> dict:
        """Assemble the keyword arguments shared by ``generate`` and ``stream``."""
        do_sample = cfg.temperature is not None and cfg.temperature > 0.0
        gen_kwargs = dict(
            max_new_tokens=cfg.max_tokens,
            do_sample=do_sample,
            top_p=cfg.top_p,
            eos_token_id=self._build_eos_ids(cfg.stop),
            pad_token_id=self.tok.pad_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = float(cfg.temperature)

        # Empty stop strings would match any text, so they are dropped.
        stops = [s for s in (cfg.stop or []) if s]
        if stops:
            gen_kwargs["stopping_criteria"] = self._stop_criteria(stops, prompt_len)
        if streamer is not None:
            gen_kwargs["streamer"] = streamer
        return gen_kwargs

    def generate(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the newly generated text.

        ``cfg`` defaults to a fresh ``GenerateConfig()`` per call.
        """
        if cfg is None:
            cfg = GenerateConfig()

        enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
        prompt_len = enc.input_ids.shape[1]

        out_ids = self.model.generate(**enc, **self._generation_kwargs(cfg, prompt_len))

        # Decode only new tokens
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def stream(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> Iterator[str]:
        """Yield generated text chunks as they are produced.

        Generation runs in a daemon thread. If it fails, the streamer is
        closed so the consumer does not block forever, and the exception is
        re-raised in the consuming thread after the last chunk.
        """
        import threading
        from transformers import TextIteratorStreamer

        if cfg is None:
            cfg = GenerateConfig()

        enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
        streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = self._generation_kwargs(cfg, enc.input_ids.shape[1], streamer=streamer)

        failure: List[BaseException] = []

        def _worker():
            try:
                self.model.generate(**enc, **gen_kwargs)
            except Exception as exc:
                failure.append(exc)
                # generate() did not finish, so it never closed the streamer
                streamer.end()

        t = threading.Thread(target=_worker, daemon=True)
        t.start()
        for chunk in streamer:
            yield chunk
        t.join()
        if failure:
            raise failure[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids without special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)
|
||||
|
||||
class AutoGPTQLoader:
    """Loader entry for GPTQ-quantized checkpoints."""

    name = "gptq"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Heuristically decide whether ``source`` is a GPTQ model.

        Local directory: True when ``quantize_config.json`` exists, or when
        ``config.json`` mentions GPTQ. A bare ``quant_method`` key is NOT
        enough, because AWQ and other quantization schemes write that key too.
        Non-existent path containing "/": treated as an HF repo id and
        accepted, so the loader gets a chance to try it.
        """
        # Local folder?
        if os.path.isdir(source):
            # quantize_config.json is a strong GPTQ signal
            if os.path.exists(os.path.join(source, "quantize_config.json")):
                return True
            # Fallback: peek at config.json
            try:
                with open(os.path.join(source, "config.json"), "r", encoding="utf-8") as f:
                    cfg = json.load(f)
            except (OSError, ValueError):
                # missing, unreadable or malformed config.json -> not ours
                return False
            text = json.dumps(cfg).lower()
            return ("gptq" in text) or ("quantize_config" in text)

        # HF repo path (heuristic): let the loader try
        return "/" in source and not os.path.exists(source)

    def load(self, source: str, **kwargs: Any) -> "UnifiedModel":
        """Build the GPTQ model wrapper for ``source``; kwargs go to its constructor."""
        return _GPTQUnified(source, **kwargs)
|
||||
132
llm_runtime/loaders/awq_loader.py
Normal file
132
llm_runtime/loaders/awq_loader.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Any, Iterator, List, Optional
|
||||
import os
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
|
||||
class _AWQUnified:
    """Model wrapper around an AutoAWQ checkpoint and its HF tokenizer.

    Exposes ``generate``, ``stream``, ``tokenize`` and ``detokenize``.
    """

    def __init__(self, src: str, **kwargs: Any):
        """Load tokenizer and AWQ-quantized model from ``src``.

        Recognised kwargs: ``trust_remote_code``, ``device_map``.

        Raises:
            ImportError: if AutoAWQ or transformers is not installed.
        """
        try:
            # The pip distribution is called "autoawq" but the importable
            # package is "awq"; the old name is kept only as a fallback.
            try:
                from awq import AutoAWQForCausalLM
            except ImportError:
                from autoawq import AutoAWQForCausalLM
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers") from e

        print(f"Loading AWQ model from: {src}")

        # NOTE(review): trust_remote_code defaults to True, which allows code
        # shipped in the model repository to execute. Kept for compatibility.
        trust_remote = kwargs.get("trust_remote_code", True)

        # Load tokenizer
        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote
        )

        # Load AWQ model
        self.model = AutoAWQForCausalLM.from_quantized(
            src,
            device_map=kwargs.get("device_map", "auto"),
            trust_remote_code=trust_remote,
            safetensors=True,
        )

        # Set pad token if not present
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Build the EOS id list for a set of stop strings.

        A stop string is added as an EOS id only when it encodes to exactly
        one token; longer stops are handled by the text-based stopping
        criterion. The tokenizer's own EOS id is always kept so the model's
        natural end-of-sequence still works. Returns ``None`` when no stop
        strings are given.
        """
        if not stop:
            return None

        ids: List[int] = []
        own_eos = getattr(self.tok, "eos_token_id", None)
        if own_eos is not None:
            ids.append(own_eos)
        for s in stop:
            if not s:
                continue
            encoded = self.tok.encode(s, add_special_tokens=False)
            if len(encoded) == 1 and encoded[0] not in ids:
                ids.append(encoded[0])
        return ids or None

    def _stop_criteria(self, stops: List[str], prompt_len: int):
        """Return a StoppingCriteriaList that fires when a stop string shows up in the NEW text."""
        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok

        class _StopOnStrings(StoppingCriteria):
            def __call__(self, input_ids, scores, **_):
                # Skip the prompt tokens: a stop string already present in
                # the prompt must not end generation at once.
                text = tok.decode(input_ids[0][prompt_len:], skip_special_tokens=True)
                return any(s in text for s in stops)

        return StoppingCriteriaList([_StopOnStrings()])

    def _generation_kwargs(self, cfg, prompt_len: int, streamer=None) -> dict:
        """Assemble the keyword arguments shared by ``generate`` and ``stream``."""
        do_sample = cfg.temperature is not None and cfg.temperature > 0.0
        gen_kwargs = dict(
            max_new_tokens=cfg.max_tokens,
            do_sample=do_sample,
            top_p=cfg.top_p,
            eos_token_id=self._build_eos_ids(cfg.stop),
            pad_token_id=self.tok.pad_token_id,
        )
        # temperature is only passed when sampling; None is never sent
        if do_sample:
            gen_kwargs["temperature"] = float(cfg.temperature)

        # Empty stop strings would match any text, so they are dropped.
        stops = [s for s in (cfg.stop or []) if s]
        if stops:
            gen_kwargs["stopping_criteria"] = self._stop_criteria(stops, prompt_len)
        if streamer is not None:
            gen_kwargs["streamer"] = streamer
        return gen_kwargs

    def generate(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the newly generated text.

        ``cfg`` defaults to a fresh ``GenerateConfig()`` per call.
        """
        if cfg is None:
            cfg = GenerateConfig()

        inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs.input_ids.shape[1]

        out_ids = self.model.generate(**inputs, **self._generation_kwargs(cfg, prompt_len))

        # Decode only the new tokens
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def stream(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> Iterator[str]:
        """Yield generated text chunks as they are produced.

        Generation runs in a daemon thread. If it fails, the streamer is
        closed so the consumer does not block forever, and the exception is
        re-raised in the consuming thread after the last chunk.
        """
        import threading
        from transformers import TextIteratorStreamer

        if cfg is None:
            cfg = GenerateConfig()

        enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = self._generation_kwargs(cfg, enc.input_ids.shape[1], streamer=streamer)

        failure: List[BaseException] = []

        def _worker():
            try:
                self.model.generate(**enc, **gen_kwargs)
            except Exception as exc:
                failure.append(exc)
                # generate() did not finish, so it never closed the streamer
                streamer.end()

        t = threading.Thread(target=_worker, daemon=True)
        t.start()

        for text in streamer:
            yield text
        t.join()
        if failure:
            raise failure[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids without special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)
|
||||
|
||||
class AWQLoader:
    """Loader entry for AWQ-quantized checkpoints."""

    name = "awq"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Heuristically decide whether ``source`` is an AWQ model.

        Local directory: True when ``config.json`` exists, parses, and
        mentions "awq" anywhere (case-insensitive). Non-existent path
        containing "/": treated as an HF repo id and accepted only when the
        id itself contains "awq".
        """
        if os.path.isdir(source):
            import json

            # Look for AWQ indicators in config.json
            config_path = os.path.join(source, "config.json")
            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except (OSError, ValueError):
                # missing, unreadable or malformed config.json -> not ours
                return False
            # A mention inside quantization_config is part of this text too,
            # so one substring test covers both indicators.
            return "awq" in str(config).lower()

        if "/" in source and not os.path.exists(source):  # HF repo
            # For HF repos, we'll let it try if it looks like AWQ
            return "awq" in source.lower()

        return False

    def load(self, source: str, **kwargs: Any) -> "UnifiedModel":
        """Build the AWQ model wrapper for ``source``; kwargs go to its constructor."""
        return _AWQUnified(source, **kwargs)
|
||||
137
llm_runtime/loaders/exl2_loader.py
Normal file
137
llm_runtime/loaders/exl2_loader.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Any, Iterator, List, Optional
|
||||
import os
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
|
||||
class _ExLlama2Unified:
|
||||
def __init__(self, src: str, **kwargs: Any):
|
||||
try:
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache
|
||||
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
|
||||
except ImportError:
|
||||
raise ImportError("exllamav2 is required for EXL2 models. Install with: pip install exllamav2")
|
||||
|
||||
print(f"Loading EXL2 model from: {src}")
|
||||
|
||||
# Configure model
|
||||
self.config = ExLlamaV2Config(src)
|
||||
|
||||
# Apply any config overrides
|
||||
if "max_seq_len" in kwargs:
|
||||
self.config.max_seq_len = kwargs["max_seq_len"]
|
||||
if "scale_pos_emb" in kwargs:
|
||||
self.config.scale_pos_emb = kwargs["scale_pos_emb"]
|
||||
if "scale_alpha_value" in kwargs:
|
||||
self.config.scale_alpha_value = kwargs["scale_alpha_value"]
|
||||
|
||||
# Initialize model
|
||||
self.model = ExLlamaV2(self.config)
|
||||
|
||||
# Load model weights
|
||||
self.model.load()
|
||||
|
||||
# Initialize tokenizer
|
||||
self.tokenizer = ExLlamaV2Tokenizer(self.config)
|
||||
|
||||
# Initialize cache
|
||||
self.cache = ExLlamaV2Cache(self.model, lazy=kwargs.get("lazy_cache", True))
|
||||
|
||||
# Initialize generator for streaming
|
||||
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
|
||||
|
||||
print(f"EXL2 model loaded successfully")
|
||||
|
||||
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
|
||||
# Create sampler settings
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
settings.temperature = cfg.temperature
|
||||
settings.top_p = cfg.top_p
|
||||
|
||||
# Set stop conditions
|
||||
if cfg.stop:
|
||||
# ExLlamaV2 expects stop strings as a list
|
||||
stop_conditions = list(cfg.stop)
|
||||
else:
|
||||
stop_conditions = []
|
||||
|
||||
# Generate text
|
||||
output = self.generator.generate_simple(
|
||||
prompt=prompt,
|
||||
max_new_tokens=cfg.max_tokens,
|
||||
seed=kwargs.get("seed", -1),
|
||||
token_healing=kwargs.get("token_healing", True),
|
||||
temperature=cfg.temperature,
|
||||
top_p=cfg.top_p,
|
||||
stop_conditions=stop_conditions,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
    """Stream a completion for ``prompt`` chunk by chunk.

    Args:
        prompt: Text to continue.
        cfg: Sampling options; ``temperature``, ``top_p``, ``max_tokens`` and
            ``stop`` are read.
        **kwargs: Optional ``seed`` (negative or ``None`` means "do not reseed")
            and ``token_healing`` (defaults to ``True``).

    Yields:
        Non-empty text chunks as the generator produces them.
    """
    # Imported here because the exllamav2 names used by the constructor appear to
    # be imported locally there and would not be visible in this method.
    from exllamav2.generator import ExLlamaV2Sampler

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = cfg.temperature
    settings.top_p = cfg.top_p

    # Let the streaming generator enforce stop strings itself. The EOS id is added
    # explicitly because installing a new list replaces the generator's defaults.
    stop_conditions = list(cfg.stop) if cfg.stop else []
    eos_id = getattr(self.tokenizer, "eos_token_id", None)
    if eos_id is not None:
        stop_conditions.append(eos_id)
    self.generator.set_stop_conditions(stop_conditions)

    seed = kwargs.get("seed", -1)
    if seed is not None and seed >= 0:
        # NOTE(review): presumably the sampler draws from the stdlib ``random``
        # module (generate_simple seeds it that way) - confirm against exllamav2.
        import random
        random.seed(seed)

    input_ids = self.tokenizer.encode(prompt)
    self.generator.begin_stream(
        input_ids,
        settings,
        token_healing=kwargs.get("token_healing", True),
    )

    # One stream() call advances generation by one step, so bound the number of
    # calls instead of accumulating the returned token tensor into a counter.
    for _ in range(cfg.max_tokens):
        chunk, eos, _chunk_tokens = self.generator.stream()

        if chunk:
            yield chunk

        # eos is also reported when one of the configured stop conditions matched.
        if eos:
            break
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
    """Encode ``text`` and return its token ids as a flat list.

    The tokenizer returns a tensor that may carry a leading batch dimension;
    it is flattened so the result is ``List[int]`` and can be passed straight
    back to ``detokenize``.
    """
    return self.tokenizer.encode(text).flatten().tolist()
|
||||
|
||||
def detokenize(self, ids: List[int]) -> str:
    """Decode a flat list of token ids back into text.

    The ids are wrapped into a one-row ``torch.long`` batch before being handed
    to the tokenizer, and the first (only) decoded entry is returned.
    """
    import torch

    batch = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    decoded = self.tokenizer.decode(batch)
    return decoded[0]
|
||||
|
||||
class ExLlama2Loader:
    """Loader entry for EXL2 models, registered under the name ``exllama2``."""

    name = "exllama2"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Return True when ``source`` looks like an EXL2 model.

        A directory qualifies when it holds a ``config.json`` plus at least one
        ``.safetensors`` file whose name contains "model" or "exl2"
        (case-insensitive). Anything that is not a directory qualifies only when
        its name ends in ``.exl2``.
        """
        if not os.path.isdir(source):
            return source.lower().endswith(".exl2")

        if not os.path.exists(os.path.join(source, "config.json")):
            return False

        return any(
            entry.endswith(".safetensors")
            and ("model" in entry.lower() or "exl2" in entry.lower())
            for entry in os.listdir(source)
        )

    def load(self, source: str, **kwargs: Any) -> UnifiedModel:
        """Build the EXL2-backed model for ``source``, forwarding all options."""
        return _ExLlama2Unified(source, **kwargs)
|
||||
91
llm_runtime/loaders/llamacpp_loader.py
Normal file
91
llm_runtime/loaders/llamacpp_loader.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Iterator, List
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
|
||||
class _LlamaCppUnified:
    """Adapter exposing a GGUF model through generate/stream/tokenize/detokenize.

    The underlying llama.cpp model is created lazily on first use, preferring the
    application's ``main._get_llama`` helper and falling back to constructing
    ``llama_cpp.Llama`` directly when ``main`` cannot be imported.
    """

    def __init__(self, model_path: str, **kwargs: Any):
        """Validate the path and remember the load options; no weights are read yet.

        Raises:
            ImportError: if llama-cpp-python is not installed.
            ValueError: if ``model_path`` does not end in ``.gguf``.
        """
        # Import llama_cpp directly instead of from main.py to avoid circular imports.
        # The name is unused here; the import only makes a missing dependency fail fast.
        from llama_cpp import Llama  # noqa: F401

        if not model_path.lower().endswith(".gguf"):
            raise ValueError(f"Not a valid GGUF model: {model_path}")

        self.model_path = model_path
        self.kwargs = kwargs
        self._llama = None

    def _get_model(self):
        """Return the llama.cpp model, creating it on the first call."""
        if self._llama is None:
            # Import the _get_llama function from main module to maintain compatibility
            try:
                import sys
                import os
                # Three directory levels above this file, so that ``main`` is importable.
                sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
                from main import _get_llama
                self._llama = _get_llama(
                    self.model_path,
                    n_ctx=self.kwargs.get("n_ctx", 8192),
                    n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
                    lora_path=self.kwargs.get("lora_path"),
                    n_threads=self.kwargs.get("n_threads"),
                )
            except ImportError:
                # Fallback to direct llama-cpp-python if main import fails
                from llama_cpp import Llama
                self._llama = Llama(
                    model_path=self.model_path,
                    n_ctx=self.kwargs.get("n_ctx", 8192),
                    n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
                    # Keep the LoRA adapter on the fallback path too; the primary
                    # path forwards it, so dropping it here changed model behaviour.
                    lora_path=self.kwargs.get("lora_path"),
                    verbose=self.kwargs.get("verbose", False),
                    n_threads=self.kwargs.get("n_threads"),
                )
        return self._llama

    def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
        """Return the completion text for ``prompt`` (the prompt is not echoed)."""
        llama = self._get_model()

        # Convert GenerateConfig to llama-cpp-python format
        result = llama(
            prompt,
            max_tokens=cfg.max_tokens,
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            stop=list(cfg.stop) if cfg.stop else None,
            echo=False,
        )

        return result["choices"][0]["text"]

    def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
        """Yield non-empty completion chunks for ``prompt`` as they are produced."""
        llama = self._get_model()

        # Use streaming generation
        for chunk in llama.create_completion(
            prompt,
            max_tokens=cfg.max_tokens,
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            stop=list(cfg.stop) if cfg.stop else None,
            stream=True,
            echo=False,
        ):
            text = chunk["choices"][0]["text"]
            if text:
                yield text

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` (as UTF-8) into token ids without a BOS token."""
        llama = self._get_model()
        return llama.tokenize(text.encode("utf-8"), add_bos=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids to text, dropping bytes that are not valid UTF-8."""
        llama = self._get_model()
        return llama.detokenize(ids).decode("utf-8", errors="ignore")
|
||||
|
||||
class LlamaCppLoader:
    """Loader entry for GGUF files, registered under the name ``llamacpp``."""

    name = "llamacpp"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Return True when ``source`` ends in ``.gguf`` (case-insensitive)."""
        normalized = source.lower()
        return normalized.endswith(".gguf")

    def load(self, source: str, **kwargs: Any) -> UnifiedModel:
        """Build the llama.cpp-backed model for ``source``, forwarding all options."""
        model = _LlamaCppUnified(source, **kwargs)
        return model
|
||||
672
llm_runtime/loaders/transformers_loader.py
Normal file
672
llm_runtime/loaders/transformers_loader.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from typing import Any, Iterator, List, Optional, Tuple, Dict
|
||||
import os
|
||||
import time
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
from llm_runtime.util_chat import apply_chat_template
|
||||
from llm_runtime.chat_session import ChatSession, get_session_manager
|
||||
from llm_runtime.device_utils import device_for_hf
|
||||
|
||||
class _HFUnified:
|
||||
def __init__(self, src: str, **kwargs: Any):
    """Load a Hugging Face causal LM and its tokenizer from ``src``.

    Options read from ``kwargs``: ``device``, ``trust_remote_code``, ``token``,
    ``torch_dtype``, ``quantization``, ``device_strategy``, ``gpu_memory_limit``,
    ``offload_folder`` and ``n_ctx``.

    Raises:
        ImportError: if torch or transformers cannot be imported.
    """
    # The access token is a credential: mask it before echoing the options.
    printable_kwargs = {
        key: ("<redacted>" if key == "token" and value else value)
        for key, value in kwargs.items()
    }
    print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={printable_kwargs}")
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        print("[HF_DEBUG] Successfully imported torch and transformers")
    except ImportError as err:
        print("[HF_DEBUG] Failed to import torch or transformers")
        raise ImportError("transformers, torch, and accelerate are required for HF models. Install with: pip install transformers torch accelerate safetensors") from err

    # Keep handles on the lazily imported modules for use by the other methods.
    self.torch = torch
    self.TextIteratorStreamer = TextIteratorStreamer

    # Normalize device for HuggingFace
    self.device = device_for_hf(kwargs.get("device"))
    print(f"Loading HF model from: {src}")
    print(f"Using device: {self.device}")

    # NOTE(review): trust_remote_code defaults to True, which lets repository code
    # run during loading; confirm this default is intended for untrusted sources.
    trust_remote_code = kwargs.get("trust_remote_code", True)

    # Load tokenizer
    self.tok = AutoTokenizer.from_pretrained(
        src,
        use_fast=True,
        trust_remote_code=trust_remote_code,
        token=kwargs.get("token"),
    )

    # Prepare quantization config if requested
    quantization_config = self._prepare_quantization_config(kwargs)

    # Prepare device mapping
    device_map, max_memory = self._prepare_device_mapping(kwargs, self.device)

    # Load model with advanced options
    load_kwargs = {
        "torch_dtype": kwargs.get("torch_dtype", "auto"),
        "device_map": device_map,
        "trust_remote_code": trust_remote_code,
        "low_cpu_mem_usage": True,
        "token": kwargs.get("token"),
    }

    if quantization_config:
        load_kwargs["quantization_config"] = quantization_config
        print(f"Using quantization: {kwargs.get('quantization', 'none')}")
        # When quantization is enabled, force device_map to "auto" to avoid device format conflicts
        load_kwargs["device_map"] = "auto"
        print(f"[DEBUG] Forcing device_map='auto' for quantization compatibility")

    if max_memory:
        load_kwargs["max_memory"] = max_memory
        print(f"Memory limits: {max_memory}")

    if kwargs.get("offload_folder"):
        load_kwargs["offload_folder"] = kwargs.get("offload_folder")
        print(f"Offloading to: {kwargs.get('offload_folder')}")

    self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs)

    # Set pad token if not present
    if self.tok.pad_token is None:
        self.tok.pad_token = self.tok.eos_token

    # Initialize chat session support with context from kwargs
    self.session_manager = get_session_manager()
    self.current_session = None

    # Store context size for session creation
    self.n_ctx = kwargs.get('n_ctx', 4096)
|
||||
|
||||
def _prepare_quantization_config(self, kwargs: Any):
|
||||
"""Prepare quantization configuration"""
|
||||
quantization = kwargs.get("quantization", "none")
|
||||
|
||||
if quantization == "none":
|
||||
return None
|
||||
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
except ImportError:
|
||||
print("Warning: BitsAndBytesConfig not available, skipping quantization")
|
||||
return None
|
||||
|
||||
if quantization == "4bit":
|
||||
return BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=self.torch.bfloat16,
|
||||
)
|
||||
elif quantization == "8bit":
|
||||
return BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
print(f"Warning: Unknown quantization type '{quantization}', skipping")
|
||||
return None
|
||||
|
||||
def _prepare_device_mapping(self, kwargs: Any, hf_device: str):
|
||||
"""Prepare device mapping and memory limits"""
|
||||
device_strategy = kwargs.get("device_strategy", "auto")
|
||||
gpu_memory_limit = kwargs.get("gpu_memory_limit", None)
|
||||
|
||||
device_map = "auto"
|
||||
max_memory = None
|
||||
|
||||
if device_strategy == "force_gpu":
|
||||
device_map = {"": hf_device}
|
||||
elif device_strategy == "balanced_split":
|
||||
# Use memory limits for balanced CPU/GPU split
|
||||
if gpu_memory_limit:
|
||||
# For quantization, use integer device index; for non-quantization, use string
|
||||
if kwargs.get("quantization", "none") != "none":
|
||||
gpu_device = 0 if hf_device.startswith("cuda:") else 0
|
||||
else:
|
||||
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
|
||||
max_memory = {
|
||||
gpu_device: f"{gpu_memory_limit}GiB",
|
||||
"cpu": "48GiB", # Large CPU limit
|
||||
}
|
||||
device_map = "auto"
|
||||
elif device_strategy == "cpu_only":
|
||||
device_map = {"": "cpu"}
|
||||
else: # auto
|
||||
if gpu_memory_limit:
|
||||
# For quantization, use integer device index; for non-quantization, use string
|
||||
if kwargs.get("quantization", "none") != "none":
|
||||
gpu_device = 0 if hf_device.startswith("cuda:") else 0
|
||||
else:
|
||||
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
|
||||
max_memory = {
|
||||
gpu_device: f"{gpu_memory_limit}GiB",
|
||||
"cpu": "32GiB",
|
||||
}
|
||||
device_map = "auto"
|
||||
|
||||
return device_map, max_memory
|
||||
|
||||
def _build_eos_ids(self, stop) -> Optional[List[int]]:
|
||||
"""Convert stop strings to token IDs"""
|
||||
if not stop:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
for s in stop:
|
||||
# Handle single character stops
|
||||
if len(s) == 1:
|
||||
tid = self.tok.convert_tokens_to_ids(s)
|
||||
if tid is not None:
|
||||
ids.append(tid)
|
||||
return ids or None
|
||||
|
||||
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
    """Generate text - uses KV cache if session_id provided, otherwise fallback to standard generation

    Args:
        prompt: Text to continue.
        cfg: Sampling options. ``None`` values for ``temperature``, ``top_p``
            and ``max_tokens`` are replaced in place by 0.7, 0.95 and 500.
        **kwargs: ``session_id`` selects the chat session (default ``"default"``);
            a falsy value disables the KV-cache path.

    Returns:
        The newly generated text without the prompt.
    """
    # Set defaults for None values
    if cfg.temperature is None:
        cfg.temperature = 0.7  # Default temperature
    if cfg.top_p is None:
        cfg.top_p = 0.95  # Default top_p
    if cfg.max_tokens is None:
        cfg.max_tokens = 500  # Default max tokens

    # Pop (not get) so session_id is not forwarded a second time through **kwargs,
    # which would raise "got multiple values for argument 'session_id'".
    session_id = kwargs.pop('session_id', 'default')  # Use 'default' session if none specified
    if session_id:
        return self.generate_with_cache(prompt, session_id, cfg, **kwargs)

    # Fallback to standard generation without KV cache
    from transformers import StoppingCriteria, StoppingCriteriaList

    inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
    prompt_length = inputs.input_ids.shape[1]

    class MultiStringStop(StoppingCriteria):
        """Stop once any stop string appears in the newly generated text."""

        def __init__(self, toks, stops, skip_tokens):
            self.toks, self.stops = toks, stops or []
            self.skip_tokens = skip_tokens

        def __call__(self, input_ids, scores, **_):
            # Look only at tokens after the prompt; a stop string that is already
            # part of the prompt must not end generation immediately.
            text = self.toks.decode(input_ids[0][self.skip_tokens:], skip_special_tokens=True)
            return any(s in text for s in self.stops)

    with self.torch.no_grad():
        out_ids = self.model.generate(
            **inputs,
            max_new_tokens=cfg.max_tokens,
            do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
            temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
            top_p=cfg.top_p,
            eos_token_id=self._build_eos_ids(cfg.stop),
            stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop, prompt_length)]) if cfg.stop else None,
            pad_token_id=self.tok.pad_token_id,
        )

    # Decode only the new tokens
    generated_text = self.tok.decode(out_ids[0][prompt_length:], skip_special_tokens=True)
    return generated_text
|
||||
|
||||
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
    """Stream text generation - uses KV cache if session_id provided, otherwise fallback to standard streaming

    Args:
        prompt: Text to continue.
        cfg: Sampling options (``max_tokens``, ``temperature``, ``top_p``, ``stop``).
        **kwargs: ``session_id`` selects the chat session (default ``"default"``);
            a falsy value disables the KV-cache path.

    Yields:
        Text pieces in generation order.
    """
    session_id = kwargs.pop('session_id', 'default')  # Remove from kwargs to avoid duplicate
    if session_id:
        yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs)
        return

    # Fallback to standard streaming without KV cache
    import threading
    from transformers import StoppingCriteria, StoppingCriteriaList

    enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
    streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)

    # Honour multi-character stop strings here as the non-streaming fallback does;
    # only text generated after the prompt is inspected.
    prompt_length = enc.input_ids.shape[1]
    tokenizer = self.tok
    stop_strings = list(cfg.stop) if cfg.stop else []

    class _StopOnStrings(StoppingCriteria):
        """Stop once any stop string appears in the newly generated text."""

        def __call__(self, input_ids, scores, **_):
            text = tokenizer.decode(input_ids[0][prompt_length:], skip_special_tokens=True)
            return any(s in text for s in stop_strings)

    stopping_criteria = StoppingCriteriaList([_StopOnStrings()]) if stop_strings else None

    def _worker():
        # Runs generation in the background; text reaches the consumer via streamer.
        with self.torch.no_grad():
            self.model.generate(
                **enc,
                max_new_tokens=cfg.max_tokens,
                do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
                temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
                top_p=cfg.top_p,
                eos_token_id=self._build_eos_ids(cfg.stop),
                stopping_criteria=stopping_criteria,
                streamer=streamer,
                pad_token_id=self.tok.pad_token_id,
            )

    t = threading.Thread(target=_worker)
    t.start()

    for text in streamer:
        yield text
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
    """Encode ``text`` into token ids without adding special tokens."""
    token_ids = self.tok.encode(text, add_special_tokens=False)
    return token_ids
|
||||
|
||||
def detokenize(self, ids: List[int]) -> str:
    """Decode token ids into text, leaving special tokens out of the result."""
    text = self.tok.decode(ids, skip_special_tokens=True)
    return text
|
||||
|
||||
def get_session(self, session_id: str = "default") -> ChatSession:
    """Look up the chat session for ``session_id`` through the session manager.

    A ``ChatSessionConfig`` whose ``max_context_length`` equals this model's
    ``n_ctx`` is passed along with the lookup.
    """
    from llm_runtime.chat_session import ChatSessionConfig

    return self.session_manager.get_session(
        session_id,
        config=ChatSessionConfig(max_context_length=self.n_ctx),
    )
|
||||
|
||||
def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Prefill phase: process the full prompt and return logits + past_key_values"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Incremental prefill: process only new tokens with existing KV cache"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=new_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Single decode step: generate next token with KV cache"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=input_ids, # Should be shape [1, 1] for single token
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int:
|
||||
"""Sample next token from logits based on generation config with optimized sampling"""
|
||||
# Get logits for the last token
|
||||
next_token_logits = logits[0, -1, :]
|
||||
|
||||
# Apply temperature scaling
|
||||
if cfg.temperature is not None and cfg.temperature != 1.0 and cfg.temperature > 0:
|
||||
next_token_logits = next_token_logits / cfg.temperature
|
||||
|
||||
# Apply top-p (nucleus) sampling if specified
|
||||
if cfg.temperature is not None and cfg.temperature > 0.0 and cfg.top_p is not None and cfg.top_p < 1.0:
|
||||
sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True)
|
||||
cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > cfg.top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
if cfg.temperature is not None and cfg.temperature > 0.0:
|
||||
# Sample from distribution
|
||||
probs = self.torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = self.torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
# Greedy sampling (deterministic)
|
||||
next_token = self.torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
||||
|
||||
return next_token.item()
|
||||
|
||||
def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool:
|
||||
"""Check if generation should stop"""
|
||||
# Check for EOS token
|
||||
if token_id == self.tok.eos_token_id:
|
||||
return True
|
||||
|
||||
# Check for custom stop strings
|
||||
if cfg.stop:
|
||||
for stop_str in cfg.stop:
|
||||
if stop_str in generated_text:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def generate_with_cache(self, prompt: str, session_id: str = "default",
                        cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
    """Generate text with persistent KV cache using manual generation loop

    The prompt is tokenized, the session's cached key/value state is reused when
    the session reports it is still valid, and tokens are then sampled one at a
    time until a stop condition is met or ``cfg.max_tokens`` tokens exist.

    Args:
        prompt: Full prompt text for this call.
        session_id: Key of the chat session whose cache is read and updated.
        cfg: Sampling options. ``None`` values for ``temperature``, ``top_p``
            and ``max_tokens`` are replaced in place on the object passed in.
        **kwargs: Accepted but not used in this method.

    Returns:
        The decoded newly generated tokens (special tokens skipped).
    """
    # Set defaults for None values
    if cfg.temperature is None:
        cfg.temperature = 0.7  # Default temperature
    if cfg.top_p is None:
        cfg.top_p = 0.95  # Default top_p
    if cfg.max_tokens is None:
        cfg.max_tokens = 500  # Default max tokens

    session = self.get_session(session_id)

    # Tokenize input
    inputs = self.tok(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(self.model.device)
    attention_mask = inputs.attention_mask.to(self.model.device)

    generated_tokens = []
    max_new_tokens = cfg.max_tokens

    # Check if we need to invalidate cache with proper token validation
    if session.should_invalidate(prompt, self.tok):
        session.invalidate_cache()
        print(f"[KV_CACHE] Cache invalidated for session {session_id}")

    # Determine if we need prefill phase
    if session.past_key_values is None:
        print(f"[KV_CACHE] Running prefill phase for {len(input_ids[0])} tokens")
        # Prefill phase: process full prompt
        logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)

        # Update session cache
        session.update_cache(past_key_values, input_ids[0].tolist(), prompt)

        # Sample first token
        next_token_id = self._sample_token(logits, cfg)
        generated_tokens.append(next_token_id)

        # Length of the prompt; used below to size the decode-phase attention mask
        current_length = input_ids.shape[1]
    else:
        print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}")
        # Use cached state
        past_key_values = session.past_key_values
        current_length = session.context_length

        # For cached state, we still need to process the new part of the prompt if it exists
        cached_length = len(session.cached_input_ids)
        if input_ids.shape[1] > cached_length:
            print(f"[KV_CACHE] Processing {input_ids.shape[1] - cached_length} new tokens (incremental prefill)")
            # NOTE(review): this assumes the first cached_length prompt tokens match
            # session.cached_input_ids; presumably should_invalidate guarantees it.
            new_tokens = input_ids[:, cached_length:]

            # Process only the new tokens with existing KV cache
            # Create attention mask that covers both cached and new tokens
            total_length = cached_length + new_tokens.shape[1]
            extended_attention = self.torch.ones((1, total_length), device=self.model.device)

            logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
            session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
            current_length = input_ids.shape[1]
        else:
            # No new tokens to process - get initial logits for generation
            # Use a forward pass with the last token to get proper logits distribution
            last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
            total_length = current_length + 1
            next_attention = self.torch.ones((1, total_length), device=self.model.device)

            logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)

        # Sample first new token
        next_token_id = self._sample_token(logits, cfg)
        generated_tokens.append(next_token_id)

    # Decode phase: generate tokens one by one with optimized attention management
    for step in range(max_new_tokens - 1):  # -1 because we already generated first token
        # Check stop conditions (the whole generated text is re-decoded every step)
        generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
        if self._should_stop(next_token_id, generated_text, cfg):
            break

        # Prepare inputs for next step - only pass the new token
        next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
        # NOTE(review): generated_tokens already contains next_token_id here, so this
        # mask looks one position longer than cache + new token - confirm intended.
        total_length = current_length + len(generated_tokens) + 1
        next_attention = self.torch.ones((1, total_length), device=self.model.device)

        # Generate next token using cached KV states
        logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
        next_token_id = self._sample_token(logits, cfg)
        generated_tokens.append(next_token_id)

    # Update session with final state - include all processed tokens
    # NOTE(review): the token sampled last is never fed through the model, so
    # past_key_values appears to cover one token fewer than final_input_ids -
    # confirm how ChatSession reconciles the two on the next call.
    final_input_ids = input_ids[0].tolist() + generated_tokens
    session.update_cache(past_key_values, final_input_ids, prompt + self.tok.decode(generated_tokens, skip_special_tokens=True))

    # Decode generated tokens
    generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
    print(f"[KV_CACHE] Generated {len(generated_tokens)} tokens with KV cache")

    return generated_text
|
||||
|
||||
def stream_with_cache(self, prompt: str, session_id: str = "default",
                      cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
    """Stream generation with persistent KV cache

    The prompt is tokenized, the session's cached key/value state is reused when
    the session reports it is still valid, and tokens are sampled one at a time.
    Text is emitted as the difference between successive decodes of everything
    generated so far, so spacing between words and multi-byte characters come
    out the same as in the non-streaming result.

    Args:
        prompt: Full prompt text for this call.
        session_id: Key of the chat session whose cache is read and updated.
        cfg: Sampling options. ``None`` values for ``temperature``, ``top_p``
            and ``max_tokens`` are replaced in place, as in ``generate_with_cache``.
        **kwargs: Accepted but not used in this method.

    Yields:
        Newly available text, in generation order.
    """
    # Set defaults for None values (otherwise range(None - 1) below would raise)
    if cfg.temperature is None:
        cfg.temperature = 0.7  # Default temperature
    if cfg.top_p is None:
        cfg.top_p = 0.95  # Default top_p
    if cfg.max_tokens is None:
        cfg.max_tokens = 500  # Default max tokens

    session = self.get_session(session_id)

    # Tokenize input
    inputs = self.tok(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(self.model.device)
    attention_mask = inputs.attention_mask.to(self.model.device)

    generated_tokens = []
    max_new_tokens = cfg.max_tokens

    # Check if we need to invalidate cache with proper token validation
    if session.should_invalidate(prompt, self.tok):
        session.invalidate_cache()
        print(f"[KV_CACHE] Cache invalidated for session {session_id}")

    # Determine if we need prefill phase
    if session.past_key_values is None:
        print(f"[KV_CACHE] Streaming prefill phase for {len(input_ids[0])} tokens")
        # Prefill phase: process full prompt
        logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)

        # Update session cache
        session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
        current_length = input_ids.shape[1]
    else:
        print(f"[KV_CACHE] Streaming with cached KV state, context length: {session.context_length}")
        # Use cached state - same logic as generate_with_cache
        past_key_values = session.past_key_values
        current_length = session.context_length

        # Handle new tokens in prompt with incremental prefill
        cached_length = len(session.cached_input_ids)
        if input_ids.shape[1] > cached_length:
            print(f"[KV_CACHE] Streaming incremental prefill for {input_ids.shape[1] - cached_length} new tokens")
            new_tokens = input_ids[:, cached_length:]

            # Process only new tokens with existing KV cache
            total_length = cached_length + new_tokens.shape[1]
            extended_attention = self.torch.ones((1, total_length), device=self.model.device)

            logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
            session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
            current_length = input_ids.shape[1]
        else:
            # No new tokens - use last cached token for initial generation
            last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
            total_length = current_length + 1
            next_attention = self.torch.ones((1, total_length), device=self.model.device)

            logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)

    # Sample first token (both branches above leave the logits to sample from)
    next_token_id = self._sample_token(logits, cfg)
    generated_tokens.append(next_token_id)

    # Number of characters of the decoded text already handed to the consumer.
    emitted_chars = 0
    generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
    # Text ending in U+FFFD is held back: the character is probably still incomplete.
    if len(generated_text) > emitted_chars and not generated_text.endswith("\ufffd"):
        yield generated_text[emitted_chars:]
        emitted_chars = len(generated_text)

    # Decode phase: stream tokens one by one using the KV cache
    for _ in range(max_new_tokens - 1):
        # Check stop conditions
        if self._should_stop(next_token_id, generated_text, cfg):
            break

        # Prepare inputs for next step - single token processing
        next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
        # NOTE(review): generated_tokens already contains next_token_id here, so this
        # mask looks one position longer than cache + new token - confirm intended.
        total_length = current_length + len(generated_tokens) + 1
        next_attention = self.torch.ones((1, total_length), device=self.model.device)

        # Generate next token using cached KV states
        logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
        next_token_id = self._sample_token(logits, cfg)
        generated_tokens.append(next_token_id)

        # Yield whatever text became available with this token
        generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
        if len(generated_text) > emitted_chars and not generated_text.endswith("\ufffd"):
            yield generated_text[emitted_chars:]
            emitted_chars = len(generated_text)

    # Flush text that was held back while waiting for a character to complete.
    if len(generated_text) > emitted_chars:
        yield generated_text[emitted_chars:]
        emitted_chars = len(generated_text)

    # Update session with final state - include all processed tokens
    # NOTE(review): the token sampled last is never fed through the model, so
    # past_key_values appears to cover one token fewer than final_input_ids -
    # confirm how ChatSession reconciles the two on the next call.
    final_input_ids = input_ids[0].tolist() + generated_tokens
    session.update_cache(past_key_values, final_input_ids, prompt + generated_text)

    print(f"[KV_CACHE] Streamed {len(generated_tokens)} tokens with persistent KV cache")
|
||||
|
||||
def clear_session_cache(self, session_id: str = "default") -> None:
    """Invalidate the KV cache held by the session ``session_id``."""
    self.get_session(session_id).invalidate_cache()
    print(f"[DEBUG] Cleared cache for session {session_id}")
|
||||
|
||||
def get_session_info(self, session_id: str = "default") -> Dict[str, Any]:
    """Return the cache information reported by the session ``session_id``."""
    return self.get_session(session_id).get_cache_info()
|
||||
|
||||
def add_conversation_turn(self, user_message: str, assistant_message: str,
                          session_id: str = "default") -> None:
    """Record one user message followed by the assistant's reply in the session."""
    session = self.get_session(session_id)
    for role, content in (("user", user_message), ("assistant", assistant_message)):
        session.add_message(role, content)
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
    """Get comprehensive information about the loaded model

    Returns:
        A dict with identity and config fields, parameter counts and the size of
        the weights. ``model_size_mb`` is computed from each parameter's actual
        element size, so half-precision and other non-float32 weights are
        reported correctly. GPU memory figures are added when CUDA is available
        and the device is not ``cpu``. On any failure a dict with an ``error``
        message is returned instead of raising.
    """
    try:
        config = self.model.config

        # Basic model info
        info = {
            "model_name": getattr(config, 'name_or_path', 'unknown'),
            "model_type": config.model_type,
            "vocab_size": config.vocab_size,
            "device": str(self.device),
            "dtype": str(self.model.dtype),
            "supports_kv_cache": True,
            "max_position_embeddings": getattr(config, 'max_position_embeddings', 'unknown'),
            "torch_compile_enabled": hasattr(self.model, '_orig_mod'),
        }

        # Memory info
        if self.torch.cuda.is_available() and str(self.device) != 'cpu':
            gib = 1024**3
            # NOTE(review): these calls report the current CUDA device and device 0,
            # which may differ from self.device on multi-GPU hosts - confirm.
            info.update({
                "gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / gib:.2f} GB",
                "gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / gib:.2f} GB",
                "gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / gib:.2f} GB",
            })

        # Model parameters: one pass collects counts and the real byte size.
        total_params = 0
        trainable_params = 0
        size_bytes = 0
        for param in self.model.parameters():
            count = param.numel()
            total_params += count
            size_bytes += count * param.element_size()
            if param.requires_grad:
                trainable_params += count

        info.update({
            "total_parameters": f"{total_params:,}",
            "trainable_parameters": f"{trainable_params:,}",
            "model_size_mb": f"{size_bytes / 1024**2:.2f} MB",
        })

        return info
    except Exception as e:
        return {"error": f"Could not get model info: {e}", "supports_kv_cache": True}
|
||||
|
||||
def get_kv_cache_stats(self) -> Dict[str, Any]:
    """Summarise KV cache usage across all sessions of the session manager.

    A session counts as active when its ``past_key_values`` is not ``None``;
    its ``context_length`` is then added to the cached-token total.

    Returns:
        A dict with ``total_sessions``, ``active_sessions``,
        ``total_cached_tokens`` and ``memory_usage_estimate`` (a string such
        as ``"1.00 MB"``; stays ``"0 MB"`` when nothing is cached or the
        model config lacks the needed fields). If gathering the statistics
        raises, a dict with only an ``"error"`` message is returned.
    """
    try:
        sessions = self.session_manager.get_all_sessions()
        stats = {
            "total_sessions": len(sessions),
            "active_sessions": 0,
            "total_cached_tokens": 0,
            "memory_usage_estimate": "0 MB"
        }

        for session_id in sessions:
            session = self.session_manager.get_session(session_id)
            if session.past_key_values is not None:
                stats["active_sessions"] += 1
                stats["total_cached_tokens"] += session.context_length

        # Rough estimate of KV cache memory usage:
        # tokens * hidden_size * num_layers * 2 (key + value) * bytes per value.
        if stats["total_cached_tokens"] > 0:
            try:
                hidden_size = self.model.config.hidden_size
                num_layers = self.model.config.num_hidden_layers
                # Use the model dtype's width when it exposes one; otherwise
                # keep the historical float32 (4-byte) assumption.
                bytes_per_value = getattr(
                    getattr(self.model, "dtype", None), "itemsize", 4
                )
                memory_bytes = (
                    stats["total_cached_tokens"] * hidden_size * num_layers * 2 * bytes_per_value
                )
                stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB"
            except Exception:
                # Best effort: the estimate is optional, so a config without
                # these fields leaves the "0 MB" placeholder in place.
                # (Not a bare except, so KeyboardInterrupt/SystemExit escape.)
                pass

        return stats
    except Exception as e:
        return {"error": f"Could not get KV cache stats: {e}"}
|
||||
|
||||
def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, Any]:
    """Run one short generation to warm the model up and time it.

    Generates with ``max_tokens=5`` and ``temperature=0.0`` in a dedicated
    "warmup" session, then clears that session's cache. The cache is cleared
    whether or not generation succeeded, so a failed warm-up does not leave
    the temporary session behind.

    Args:
        test_prompt: Prompt used for the warm-up generation.

    Returns:
        ``{"warmup_time": <seconds>, "status": "success"}`` on success, or
        ``{"warmup_time": 0.0, "status": "failed", "error": <message>}`` if
        generation or the subsequent cleanup raised.
    """
    warmup_session = "warmup"

    try:
        print("[KV_CACHE] Warming up model...")
        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()

        # Simple generation to warm up CUDA kernels
        cfg = GenerateConfig(max_tokens=5, temperature=0.0)
        _ = self.generate(test_prompt, cfg=cfg, session_id=warmup_session)

        result = {
            "warmup_time": time.perf_counter() - start_time,
            "status": "success"
        }
    except Exception as e:
        result = {
            "warmup_time": 0.0,
            "status": "failed",
            "error": str(e)
        }

    # Clean up warmup session on both paths.
    try:
        self.clear_session_cache(warmup_session)
    except Exception as e:
        # A cleanup failure after a successful run is reported as a failure;
        # after a failed run the original generation error is kept because it
        # is the more useful one.
        if result["status"] == "success":
            result = {
                "warmup_time": 0.0,
                "status": "failed",
                "error": str(e)
            }

    return result
|
||||
|
||||
class HFTransformersLoader:
    """Loader for Hugging Face style sources (repo ids, model dirs, weight files)."""

    name = "hf"

    # Weight-file extensions whose sibling config.json marks an HF checkpoint.
    _WEIGHT_SUFFIXES = ('.safetensors', '.bin')

    @staticmethod
    def _has_config(directory: str) -> bool:
        """Return True when ``directory`` contains a ``config.json`` entry."""
        return os.path.exists(os.path.join(directory, "config.json"))

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Decide whether ``source`` looks loadable by this loader.

        Accepted sources, checked in this order:
          1. a string containing "/" that does not exist on disk
             (treated as a repo id like "microsoft/DialoGPT-medium");
          2. a directory that contains ``config.json``;
          3. a path ending in ``.safetensors`` or ``.bin`` (case-insensitive)
             whose parent directory contains ``config.json``.

        Anything else is rejected. Extra keyword arguments are ignored.
        """
        missing_on_disk = not os.path.exists(source)
        if "/" in source and missing_on_disk:
            return True

        if os.path.isdir(source) and self._has_config(source):
            return True

        # Single weight file (.safetensors or PyTorch .bin): the config is
        # expected next to it.
        if source.lower().endswith(self._WEIGHT_SUFFIXES):
            return self._has_config(os.path.dirname(source))

        return False

    def load(self, source: str, **kwargs: Any) -> UnifiedModel:
        """Build the unified HF model wrapper for ``source``, forwarding ``kwargs``."""
        return _HFUnified(source, **kwargs)
|
||||
55
llm_runtime/model_router.py
Normal file
55
llm_runtime/model_router.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
import os, json
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Dict, Tuple, Callable
|
||||
|
||||
# Only support Hugging Face (Safetensors) and GGUF.
# Closed set of loader identifiers; detect_loader_type() returns one of these
# as the first element of its (kind, reason) tuple.
LoaderKind = Literal["hf", "gguf"]
|
||||
|
||||
def detect_loader_type(source: str) -> Tuple[LoaderKind, str]:
    """Pick the loader for a model path or repo id.

    Resolution order:
      * not on disk and starting with "meta-llama/"  -> "hf";
      * existing file: ".gguf" suffix -> "gguf", anything else -> "hf";
      * existing directory: any "*.gguf" entry, or a parseable config.json
        mentioning gguf / llama.cpp / ggml -> "gguf", otherwise "hf";
      * not on disk and repo-like ("/" present, or exactly one backslash):
        name mentioning gguf / ggml / llama.cpp -> "gguf", otherwise "hf";
      * everything else -> "hf".

    Returns:
        ``(kind, reason)`` where ``kind`` is "hf" or "gguf" and ``reason`` is
        a short human-readable explanation of the decision.
    """
    print(f"[ROUTER_DEBUG] detect_loader_type() called with source: '{source}'")
    path = Path(source)
    lowered = source.lower()
    on_disk = path.exists()

    if not on_disk:
        # Official Meta repos are always routed to HF.
        if lowered.startswith("meta-llama/"):
            result = ("hf", "Official Meta repo (full-precision).")
            print(f"[ROUTER_DEBUG] Meta repo detected -> {result}")
            return result

        repo_like = "/" in source or source.count("\\") == 1
        if not repo_like:
            return "hf", "Fallback to HF."
        if any(tag in lowered for tag in ["gguf", "ggml", "llama.cpp"]):
            return "gguf", "Repo name suggests GGUF."
        return "hf", "Remote repo (default HF)."

    if path.is_file():
        if path.suffix.lower() == ".gguf":
            return "gguf", "Local .gguf file."
        return "hf", "Local non-.gguf file (default HF)."

    if path.is_dir():
        # A single *.gguf entry is enough to route the directory to GGUF.
        if next(path.glob("*.gguf"), None) is not None:
            return "gguf", "Directory contains .gguf file(s)."

        # config.json is only consulted for GGUF hints; an unreadable or
        # invalid file is treated as carrying no hint.
        config_file = path / "config.json"
        if config_file.exists():
            try:
                parsed = json.loads(config_file.read_text(encoding="utf-8"))
                flattened = json.dumps(parsed).lower()
            except Exception:
                flattened = ""
            if any(marker in flattened for marker in ("gguf", "llama.cpp", "ggml")):
                return "gguf", "config.json mentions GGUF/llama.cpp."
        return "hf", "Default HF for non-GGUF directories."

    # Exists but is neither a regular file nor a directory.
    return "hf", "Fallback to HF."
|
||||
10
llm_runtime/registry.py
Normal file
10
llm_runtime/registry.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Any
|
||||
from .types import UnifiedModel
|
||||
from .loader_factory import load_model_for_gui
|
||||
|
||||
def load_model(source: str, **kwargs: Any) -> UnifiedModel:
    """Load a model through the router-based factory and return only the model.

    Delegates to ``load_model_for_gui(source, **kwargs)``, which yields a
    ``(model, kind, reason)`` triple; the kind and reason are only written to
    the debug output.
    """
    print(f"[REGISTRY_DEBUG] load_model() called with source='{source}', kwargs={kwargs}")
    outcome = load_model_for_gui(source, **kwargs)
    model, kind, reason = outcome
    print(f"[REGISTRY_DEBUG] load_model_for_gui() returned: kind='{kind}', reason='{reason}'")
    return model
|
||||
23
llm_runtime/types.py
Normal file
23
llm_runtime/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Iterator, List, Optional, Protocol, Any
|
||||
|
||||
@dataclass
class GenerateConfig:
    """Optional generation settings accepted by ``UnifiedModel.generate``/``stream``.

    Every field defaults to ``None``, which means "use model defaults".
    """
    # None means "use model defaults"
    # presumably the cap on newly generated tokens — confirm against the model wrappers
    max_tokens: Optional[int] = None
    # sampling temperature
    temperature: Optional[float] = None
    # nucleus (top-p) sampling value
    top_p: Optional[float] = None
    # presumably strings that end generation when produced — verify against the model wrappers
    stop: Optional[Iterable[str]] = None
    # Optional context window hint; models may ignore if they manage it internally
    context_window: Optional[int] = None
    # Optional advanced knobs; models may ignore if unsupported
    min_p: Optional[float] = None
    typical_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
|
||||
|
||||
class UnifiedModel(Protocol):
    """Structural interface implemented by every model wrapper.

    Any object providing these four methods satisfies the protocol; no
    inheritance is required.
    """
    # Default config uses None for all tunables so the model can choose
    # NOTE(review): the GenerateConfig() default below is evaluated once when
    # the class body runs, so the same instance is shared by every call that
    # omits cfg; implementations should not mutate it.

    # Produce the full completion for `prompt` as a single string.
    def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str: ...
    # Produce the completion for `prompt` as an iterator of string pieces.
    def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]: ...
    # Convert text to a list of integer token ids.
    def tokenize(self, text: str) -> List[int]: ...
    # Convert a list of integer token ids back to text.
    def detokenize(self, ids: List[int]) -> str: ...
|
||||
23
llm_runtime/util_chat.py
Normal file
23
llm_runtime/util_chat.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render chat messages into a single prompt string.

    The tokenizer's own ``apply_chat_template`` is used when it is available
    and succeeds. If the tokenizer has no such method, or the method raises
    ``ValueError`` (which Hugging Face tokenizers do when no chat template is
    configured), a simple ChatML-like format is produced instead.

    Args:
        tokenizer: Object that may expose ``apply_chat_template``.
        messages: Dicts with optional "role" and "content" keys. A missing
            role is treated as "user"; any role other than "system" or
            "assistant" is rendered as a user turn; missing or ``None``
            content is rendered as empty text.
        add_generation_prompt: When true, end the prompt with an open
            assistant turn.

    Returns:
        The rendered prompt string.
    """
    render = getattr(tokenizer, "apply_chat_template", None)
    if callable(render):
        try:
            return render(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
        except ValueError:
            # No usable template on this tokenizer: use the generic format
            # below rather than failing the whole request. Other exception
            # types (e.g. template errors) still propagate.
            pass

    # Fallback: simple ChatML-ish format
    parts = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        if content is None:
            # e.g. tool-call messages; avoid emitting the literal text "None"
            content = ""

        if role == "system":
            tag = "<|system|>"
        elif role == "assistant":
            tag = "<|assistant|>"
        else:
            tag = "<|user|>"
        parts.append(f"{tag}\n{content}\n")

    if add_generation_prompt:
        parts.append("<|assistant|>\n")

    return "".join(parts)
|
||||
Reference in New Issue
Block a user