first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,176 @@
from typing import Any, Iterator, List, Optional
import os, json
import torch
from llm_runtime.types import UnifiedModel, GenerateConfig
from llm_runtime.device_utils import device_for_gptq
def _inputs_device_from_gptq_device(dev) -> torch.device:
"""
AutoGPTQ wants: int GPU index | 'cpu' | 'mps' | 'disk'
But tokenizer tensors need a torch.device:
- int -> 'cuda:{idx}'
- 'cpu' -> 'cpu'
- 'mps' -> 'mps'
- 'disk' -> still run the forward on CUDA/CPU; safest default: 'cpu'
"""
if isinstance(dev, int):
return torch.device(f"cuda:{dev}")
if isinstance(dev, str):
if dev in ("cpu", "mps"):
return torch.device(dev)
if dev == "disk":
# inputs live on CPU; model will page as needed
return torch.device("cpu")
# Fallback
return torch.device("cpu")
class _GPTQUnified:
def __init__(self, src: str, **kwargs: Any):
print(f"[GPTQ_DEBUG] _GPTQUnified.__init__() called with src='{src}', kwargs={kwargs}")
try:
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
print("[GPTQ_DEBUG] Successfully imported auto_gptq and transformers")
except ImportError as e:
print(f"[GPTQ_DEBUG] Failed to import auto_gptq or transformers: {e}")
raise ImportError("auto-gptq and transformers are required for GPTQ models. Install with: pip install auto-gptq transformers")
print(f"[GPTQ_DEBUG] Loading GPTQ model from: {src}")
# Normalize device for AutoGPTQ (int or 'cpu'/'mps'/'disk')
raw_device = kwargs.get("device")
print(f"[GPTQ_DEBUG] Raw device from kwargs: {raw_device}")
self._gptq_dev = device_for_gptq(raw_device)
print(f"[GPTQ_DEBUG] Normalized device for GPTQ: {self._gptq_dev}")
self._inputs_device = _inputs_device_from_gptq_device(self._gptq_dev)
print(f"[GPTQ_DEBUG] Using device (gptq): {self._gptq_dev} | inputs will go to: {self._inputs_device}")
trust_remote = kwargs.get("trust_remote_code", True)
token = kwargs.get("token")
# Tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=trust_remote,
token=token
)
# Model
self.model = AutoGPTQForCausalLM.from_quantized(
src,
device=self._gptq_dev, # int or 'cpu'/'mps'/'disk'
trust_remote_code=trust_remote,
use_safetensors=True,
use_triton=kwargs.get("use_triton", False),
token=token
)
# Pad token safety
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Encode stop strings to token IDs (take the first token of each stop string)."""
if not stop:
return None
out: List[int] = []
for s in stop:
if not s:
continue
ids = self.tok.encode(s, add_special_tokens=False)
if ids:
out.append(ids[0])
return out or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
from transformers import StoppingCriteria, StoppingCriteriaList
enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
# Simple but effective; for high perf, implement a token-level matcher.
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
do_sample = cfg.temperature is not None and cfg.temperature > 0.0
gen_kwargs = dict(
max_new_tokens=cfg.max_tokens,
do_sample=do_sample,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
pad_token_id=self.tok.pad_token_id,
)
if do_sample:
gen_kwargs["temperature"] = float(cfg.temperature)
if cfg.stop:
gen_kwargs["stopping_criteria"] = StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)])
out_ids = self.model.generate(**enc, **gen_kwargs)
# Decode only new tokens
new_tokens = out_ids[0][enc.input_ids.shape[1]:]
return self.tok.decode(new_tokens, skip_special_tokens=True)
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
import threading
from transformers import TextIteratorStreamer
enc = self.tok(prompt, return_tensors="pt").to(self._inputs_device)
streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
do_sample = cfg.temperature is not None and cfg.temperature > 0.0
gen_kwargs = dict(
max_new_tokens=cfg.max_tokens,
do_sample=do_sample,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
if do_sample:
gen_kwargs["temperature"] = float(cfg.temperature)
def _worker():
self.model.generate(**enc, **gen_kwargs)
t = threading.Thread(target=_worker, daemon=True)
t.start()
for chunk in streamer:
yield chunk
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
class AutoGPTQLoader:
name = "gptq"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Local folder?
if os.path.isdir(source):
# quantize_config.json is a strong GPTQ signal
qc = os.path.join(source, "quantize_config.json")
if os.path.exists(qc):
return True
# Fallback: peek at config.json
try:
with open(os.path.join(source, "config.json"), "r", encoding="utf-8") as f:
cfg = json.load(f)
text = json.dumps(cfg).lower()
return ("gptq" in text) or ("quantize_config" in text) or ("quant_method" in text)
except Exception:
pass
# HF repo path (heuristic): let the loader try
elif "/" in source and not os.path.exists(source):
return True
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _GPTQUnified(source, **kwargs)

View File

@@ -0,0 +1,132 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _AWQUnified:
def __init__(self, src: str, **kwargs: Any):
try:
from autoawq import AutoAWQForCausalLM
from transformers import AutoTokenizer
except ImportError:
raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers")
print(f"Loading AWQ model from: {src}")
# Load tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=kwargs.get("trust_remote_code", True)
)
# Load AWQ model
self.model = AutoAWQForCausalLM.from_quantized(
src,
device_map=kwargs.get("device_map", "auto"),
trust_remote_code=kwargs.get("trust_remote_code", True),
safetensors=True,
)
# Set pad token if not present
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Convert stop strings to token IDs"""
if not stop:
return None
ids = []
for s in stop:
if len(s) == 1:
tid = self.tok.convert_tokens_to_ids(s)
if tid is not None:
ids.append(tid)
return ids or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
from transformers import StoppingCriteria, StoppingCriteriaList
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
out_ids = self.model.generate(
**inputs,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
pad_token_id=self.tok.pad_token_id,
)
# Decode only the new tokens
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return generated_text
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
import threading
from transformers import TextIteratorStreamer
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
def _worker():
self.model.generate(
**enc,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
t = threading.Thread(target=_worker)
t.start()
for text in streamer:
yield text
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
class AWQLoader:
name = "awq"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Check for AWQ indicators
if os.path.isdir(source):
# Look for AWQ indicators in config.json
try:
import json
config_path = os.path.join(source, "config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
config_str = str(config).lower()
return ("awq" in config_str or
"quantization_config" in config and
"awq" in str(config.get("quantization_config", {})).lower())
except:
pass
elif "/" in source and not os.path.exists(source): # HF repo
# For HF repos, we'll let it try if it looks like AWQ
return "awq" in source.lower()
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _AWQUnified(source, **kwargs)

View File

@@ -0,0 +1,137 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _ExLlama2Unified:
def __init__(self, src: str, **kwargs: Any):
try:
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
except ImportError:
raise ImportError("exllamav2 is required for EXL2 models. Install with: pip install exllamav2")
print(f"Loading EXL2 model from: {src}")
# Configure model
self.config = ExLlamaV2Config(src)
# Apply any config overrides
if "max_seq_len" in kwargs:
self.config.max_seq_len = kwargs["max_seq_len"]
if "scale_pos_emb" in kwargs:
self.config.scale_pos_emb = kwargs["scale_pos_emb"]
if "scale_alpha_value" in kwargs:
self.config.scale_alpha_value = kwargs["scale_alpha_value"]
# Initialize model
self.model = ExLlamaV2(self.config)
# Load model weights
self.model.load()
# Initialize tokenizer
self.tokenizer = ExLlamaV2Tokenizer(self.config)
# Initialize cache
self.cache = ExLlamaV2Cache(self.model, lazy=kwargs.get("lazy_cache", True))
# Initialize generator for streaming
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
print(f"EXL2 model loaded successfully")
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
# ExLlamaV2 expects stop strings as a list
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Generate text
output = self.generator.generate_simple(
prompt=prompt,
max_new_tokens=cfg.max_tokens,
seed=kwargs.get("seed", -1),
token_healing=kwargs.get("token_healing", True),
temperature=cfg.temperature,
top_p=cfg.top_p,
stop_conditions=stop_conditions,
)
return output
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Begin streaming generation
input_ids = self.tokenizer.encode(prompt)
self.generator.begin_stream(
input_ids=input_ids,
gen_settings=settings,
token_healing=kwargs.get("token_healing", True),
seed=kwargs.get("seed", -1),
)
generated_tokens = 0
while generated_tokens < cfg.max_tokens:
chunk, eos, tokens = self.generator.stream()
if chunk:
yield chunk
generated_tokens += tokens
# Check for stop conditions
if eos:
break
if cfg.stop:
# Check if any stop condition is met in the generated text so far
current_text = self.generator.sequence_str
if any(stop in current_text for stop in cfg.stop):
break
def tokenize(self, text: str) -> List[int]:
return self.tokenizer.encode(text).tolist()
def detokenize(self, ids: List[int]) -> str:
import torch
tensor_ids = torch.tensor([ids], dtype=torch.long)
return self.tokenizer.decode(tensor_ids)[0]
class ExLlama2Loader:
name = "exllama2"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Check for EXL2 model directory structure
if os.path.isdir(source):
# Look for config.json and .safetensors files that indicate EXL2
config_path = os.path.join(source, "config.json")
if os.path.exists(config_path):
# Check if there are .safetensors files with EXL2 naming pattern
for file in os.listdir(source):
if file.endswith(".safetensors") and ("model" in file.lower() or "exl2" in file.lower()):
return True
elif source.lower().endswith(".exl2"):
return True
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _ExLlama2Unified(source, **kwargs)

View File

@@ -0,0 +1,91 @@
from typing import Any, Iterator, List
from llm_runtime.types import UnifiedModel, GenerateConfig
class _LlamaCppUnified:
def __init__(self, model_path: str, **kwargs: Any):
# Import llama_cpp directly instead of from main.py to avoid circular imports
from llama_cpp import Llama
if not model_path.lower().endswith(".gguf"):
raise ValueError(f"Not a valid GGUF model: {model_path}")
self.model_path = model_path
self.kwargs = kwargs
self._llama = None
def _get_model(self):
"""Lazy load the model using existing implementation"""
if self._llama is None:
# Import the _get_llama function from main module to maintain compatibility
try:
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from main import _get_llama
self._llama = _get_llama(
self.model_path,
n_ctx=self.kwargs.get("n_ctx", 8192),
n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
lora_path=self.kwargs.get("lora_path"),
n_threads=self.kwargs.get("n_threads"),
)
except ImportError:
# Fallback to direct llama-cpp-python if main import fails
from llama_cpp import Llama
self._llama = Llama(
model_path=self.model_path,
n_ctx=self.kwargs.get("n_ctx", 8192),
n_gpu_layers=self.kwargs.get("n_gpu_layers", 0),
verbose=self.kwargs.get("verbose", False),
n_threads=self.kwargs.get("n_threads")
)
return self._llama
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
llama = self._get_model()
# Convert GenerateConfig to llama-cpp-python format
result = llama(
prompt,
max_tokens=cfg.max_tokens,
temperature=cfg.temperature,
top_p=cfg.top_p,
stop=list(cfg.stop) if cfg.stop else None,
echo=False
)
return result["choices"][0]["text"]
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
llama = self._get_model()
# Use streaming generation
for chunk in llama.create_completion(
prompt,
max_tokens=cfg.max_tokens,
temperature=cfg.temperature,
top_p=cfg.top_p,
stop=list(cfg.stop) if cfg.stop else None,
stream=True,
echo=False
):
text = chunk["choices"][0]["text"]
if text:
yield text
def tokenize(self, text: str) -> List[int]:
llama = self._get_model()
return llama.tokenize(text.encode("utf-8"), add_bos=False)
def detokenize(self, ids: List[int]) -> str:
llama = self._get_model()
return llama.detokenize(ids).decode("utf-8", errors="ignore")
class LlamaCppLoader:
name = "llamacpp"
def can_load(self, source: str, **kwargs: Any) -> bool:
return source.lower().endswith(".gguf")
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _LlamaCppUnified(source, **kwargs)

View File

@@ -0,0 +1,672 @@
from typing import Any, Iterator, List, Optional, Tuple, Dict
import os
import time
from llm_runtime.types import UnifiedModel, GenerateConfig
from llm_runtime.util_chat import apply_chat_template
from llm_runtime.chat_session import ChatSession, get_session_manager
from llm_runtime.device_utils import device_for_hf
class _HFUnified:
def __init__(self, src: str, **kwargs: Any):
print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={kwargs}")
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
print("[HF_DEBUG] Successfully imported torch and transformers")
except ImportError:
print("[HF_DEBUG] Failed to import torch or transformers")
raise ImportError("transformers, torch, and accelerate are required for HF models. Install with: pip install transformers torch accelerate safetensors")
self.torch = torch
self.TextIteratorStreamer = TextIteratorStreamer
# Normalize device for HuggingFace
self.device = device_for_hf(kwargs.get("device"))
print(f"Loading HF model from: {src}")
print(f"Using device: {self.device}")
# Load tokenizer
self.tok = AutoTokenizer.from_pretrained(
src,
use_fast=True,
trust_remote_code=kwargs.get("trust_remote_code", True),
token=kwargs.get("token")
)
# Prepare quantization config if requested
quantization_config = self._prepare_quantization_config(kwargs)
# Prepare device mapping
device_map, max_memory = self._prepare_device_mapping(kwargs, self.device)
# Load model with advanced options
load_kwargs = {
"torch_dtype": kwargs.get("torch_dtype", "auto"),
"device_map": device_map,
"trust_remote_code": kwargs.get("trust_remote_code", True),
"low_cpu_mem_usage": True,
"token": kwargs.get("token")
}
if quantization_config:
load_kwargs["quantization_config"] = quantization_config
print(f"Using quantization: {kwargs.get('quantization', 'none')}")
# When quantization is enabled, force device_map to "auto" to avoid device format conflicts
load_kwargs["device_map"] = "auto"
print(f"[DEBUG] Forcing device_map='auto' for quantization compatibility")
if max_memory:
load_kwargs["max_memory"] = max_memory
print(f"Memory limits: {max_memory}")
if kwargs.get("offload_folder"):
load_kwargs["offload_folder"] = kwargs.get("offload_folder")
print(f"Offloading to: {kwargs.get('offload_folder')}")
self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs)
# Set pad token if not present
if self.tok.pad_token is None:
self.tok.pad_token = self.tok.eos_token
# Initialize chat session support with context from kwargs
self.session_manager = get_session_manager()
self.current_session = None
# Store context size for session creation
self.n_ctx = kwargs.get('n_ctx', 4096)
def _prepare_quantization_config(self, kwargs: Any):
"""Prepare quantization configuration"""
quantization = kwargs.get("quantization", "none")
if quantization == "none":
return None
try:
from transformers import BitsAndBytesConfig
except ImportError:
print("Warning: BitsAndBytesConfig not available, skipping quantization")
return None
if quantization == "4bit":
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=self.torch.bfloat16,
)
elif quantization == "8bit":
return BitsAndBytesConfig(
load_in_8bit=True,
)
else:
print(f"Warning: Unknown quantization type '{quantization}', skipping")
return None
def _prepare_device_mapping(self, kwargs: Any, hf_device: str):
"""Prepare device mapping and memory limits"""
device_strategy = kwargs.get("device_strategy", "auto")
gpu_memory_limit = kwargs.get("gpu_memory_limit", None)
device_map = "auto"
max_memory = None
if device_strategy == "force_gpu":
device_map = {"": hf_device}
elif device_strategy == "balanced_split":
# Use memory limits for balanced CPU/GPU split
if gpu_memory_limit:
# For quantization, use integer device index; for non-quantization, use string
if kwargs.get("quantization", "none") != "none":
gpu_device = 0 if hf_device.startswith("cuda:") else 0
else:
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
max_memory = {
gpu_device: f"{gpu_memory_limit}GiB",
"cpu": "48GiB", # Large CPU limit
}
device_map = "auto"
elif device_strategy == "cpu_only":
device_map = {"": "cpu"}
else: # auto
if gpu_memory_limit:
# For quantization, use integer device index; for non-quantization, use string
if kwargs.get("quantization", "none") != "none":
gpu_device = 0 if hf_device.startswith("cuda:") else 0
else:
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
max_memory = {
gpu_device: f"{gpu_memory_limit}GiB",
"cpu": "32GiB",
}
device_map = "auto"
return device_map, max_memory
def _build_eos_ids(self, stop) -> Optional[List[int]]:
"""Convert stop strings to token IDs"""
if not stop:
return None
ids = []
for s in stop:
# Handle single character stops
if len(s) == 1:
tid = self.tok.convert_tokens_to_ids(s)
if tid is not None:
ids.append(tid)
return ids or None
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
"""Generate text - uses KV cache if session_id provided, otherwise fallback to standard generation"""
# Set defaults for None values
if cfg.temperature is None:
cfg.temperature = 0.7 # Default temperature
if cfg.top_p is None:
cfg.top_p = 0.95 # Default top_p
if cfg.max_tokens is None:
cfg.max_tokens = 500 # Default max tokens
session_id = kwargs.get('session_id', 'default') # Use 'default' session if none specified
if session_id:
return self.generate_with_cache(prompt, session_id, cfg, **kwargs)
# Fallback to standard generation without KV cache
from transformers import StoppingCriteria, StoppingCriteriaList
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
class MultiStringStop(StoppingCriteria):
def __init__(self, toks, stops):
self.toks, self.stops = toks, stops or []
def __call__(self, input_ids, scores, **_):
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
return any(s in text for s in self.stops)
with self.torch.no_grad():
out_ids = self.model.generate(
**inputs,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
pad_token_id=self.tok.pad_token_id,
)
# Decode only the new tokens
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return generated_text
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
"""Stream text generation - uses KV cache if session_id provided, otherwise fallback to standard streaming"""
session_id = kwargs.pop('session_id', 'default') # Remove from kwargs to avoid duplicate
if session_id:
yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs)
return
# Fallback to standard streaming without KV cache
import threading
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
def _worker():
with self.torch.no_grad():
self.model.generate(
**enc,
max_new_tokens=cfg.max_tokens,
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
top_p=cfg.top_p,
eos_token_id=self._build_eos_ids(cfg.stop),
streamer=streamer,
pad_token_id=self.tok.pad_token_id,
)
t = threading.Thread(target=_worker)
t.start()
for text in streamer:
yield text
def tokenize(self, text: str) -> List[int]:
return self.tok.encode(text, add_special_tokens=False)
def detokenize(self, ids: List[int]) -> str:
return self.tok.decode(ids, skip_special_tokens=True)
def get_session(self, session_id: str = "default") -> ChatSession:
"""Get or create a chat session for KV caching"""
# Create session config with the correct context size
from llm_runtime.chat_session import ChatSessionConfig
session_config = ChatSessionConfig(max_context_length=self.n_ctx)
return self.session_manager.get_session(session_id, config=session_config)
def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Prefill phase: process the full prompt and return logits + past_key_values"""
with self.torch.no_grad():
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Incremental prefill: process only new tokens with existing KV cache"""
with self.torch.no_grad():
outputs = self.model(
input_ids=new_input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
"""Single decode step: generate next token with KV cache"""
with self.torch.no_grad():
outputs = self.model(
input_ids=input_ids, # Should be shape [1, 1] for single token
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
return outputs.logits, outputs.past_key_values
def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int:
"""Sample next token from logits based on generation config with optimized sampling"""
# Get logits for the last token
next_token_logits = logits[0, -1, :]
# Apply temperature scaling
if cfg.temperature is not None and cfg.temperature != 1.0 and cfg.temperature > 0:
next_token_logits = next_token_logits / cfg.temperature
# Apply top-p (nucleus) sampling if specified
if cfg.temperature is not None and cfg.temperature > 0.0 and cfg.top_p is not None and cfg.top_p < 1.0:
sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True)
cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > cfg.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
if cfg.temperature is not None and cfg.temperature > 0.0:
# Sample from distribution
probs = self.torch.softmax(next_token_logits, dim=-1)
next_token = self.torch.multinomial(probs, num_samples=1)
else:
# Greedy sampling (deterministic)
next_token = self.torch.argmax(next_token_logits, dim=-1, keepdim=True)
return next_token.item()
def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool:
"""Check if generation should stop"""
# Check for EOS token
if token_id == self.tok.eos_token_id:
return True
# Check for custom stop strings
if cfg.stop:
for stop_str in cfg.stop:
if stop_str in generated_text:
return True
return False
def generate_with_cache(self, prompt: str, session_id: str = "default",
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
"""Generate text with persistent KV cache using manual generation loop"""
# Set defaults for None values
if cfg.temperature is None:
cfg.temperature = 0.7 # Default temperature
if cfg.top_p is None:
cfg.top_p = 0.95 # Default top_p
if cfg.max_tokens is None:
cfg.max_tokens = 500 # Default max tokens
session = self.get_session(session_id)
# Tokenize input
inputs = self.tok(prompt, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(self.model.device)
attention_mask = inputs.attention_mask.to(self.model.device)
generated_tokens = []
max_new_tokens = cfg.max_tokens
# Check if we need to invalidate cache with proper token validation
if session.should_invalidate(prompt, self.tok):
session.invalidate_cache()
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
# Determine if we need prefill phase
if session.past_key_values is None:
print(f"[KV_CACHE] Running prefill phase for {len(input_ids[0])} tokens")
# Prefill phase: process full prompt
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
# Update session cache
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Update input_ids and attention_mask for decode phase
current_length = input_ids.shape[1]
else:
print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}")
# Use cached state
past_key_values = session.past_key_values
current_length = session.context_length
# For cached state, we still need to process the new part of the prompt if it exists
cached_length = len(session.cached_input_ids)
if input_ids.shape[1] > cached_length:
print(f"[KV_CACHE] Processing {input_ids.shape[1] - cached_length} new tokens (incremental prefill)")
new_tokens = input_ids[:, cached_length:]
# Process only the new tokens with existing KV cache
# Create attention mask that covers both cached and new tokens
total_length = cached_length + new_tokens.shape[1]
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
current_length = input_ids.shape[1]
else:
# No new tokens to process - get initial logits for generation
# Use a forward pass with the last token to get proper logits distribution
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
total_length = current_length + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
# Sample first new token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Decode phase: generate tokens one by one with optimized attention management
for step in range(max_new_tokens - 1): # -1 because we already generated first token
# Check stop conditions
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
if self._should_stop(next_token_id, generated_text, cfg):
break
# Prepare inputs for next step - only pass the new token
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
total_length = current_length + len(generated_tokens) + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
# Generate next token using cached KV states
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Update session with final state - include all processed tokens
final_input_ids = input_ids[0].tolist() + generated_tokens
session.update_cache(past_key_values, final_input_ids, prompt + self.tok.decode(generated_tokens, skip_special_tokens=True))
# Decode generated tokens
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
print(f"[KV_CACHE] Generated {len(generated_tokens)} tokens with KV cache")
return generated_text
def stream_with_cache(self, prompt: str, session_id: str = "default",
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
"""Stream generation with persistent KV cache"""
session = self.get_session(session_id)
# Tokenize input
inputs = self.tok(prompt, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(self.model.device)
attention_mask = inputs.attention_mask.to(self.model.device)
generated_tokens = []
max_new_tokens = cfg.max_tokens
# Check if we need to invalidate cache with proper token validation
if session.should_invalidate(prompt, self.tok):
session.invalidate_cache()
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
# Determine if we need prefill phase
if session.past_key_values is None:
print(f"[KV_CACHE] Streaming prefill phase for {len(input_ids[0])} tokens")
# Prefill phase: process full prompt
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
# Update session cache
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
current_length = input_ids.shape[1]
# Yield first token
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if first_text:
yield first_text
else:
print(f"[KV_CACHE] Streaming with cached KV state, context length: {session.context_length}")
# Use cached state - optimized logic similar to generate_with_cache
past_key_values = session.past_key_values
current_length = session.context_length
# Handle new tokens in prompt with incremental prefill
cached_length = len(session.cached_input_ids)
if input_ids.shape[1] > cached_length:
print(f"[KV_CACHE] Streaming incremental prefill for {input_ids.shape[1] - cached_length} new tokens")
new_tokens = input_ids[:, cached_length:]
# Process only new tokens with existing KV cache
total_length = cached_length + new_tokens.shape[1]
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
current_length = input_ids.shape[1]
else:
# No new tokens - use last cached token for initial generation
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
total_length = current_length + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
# Sample first token
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Yield first token
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if first_text:
yield first_text
# Decode phase: stream tokens one by one with optimized KV cache usage
for step in range(max_new_tokens - 1):
# Check stop conditions
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
if self._should_stop(next_token_id, generated_text, cfg):
break
# Prepare inputs for next step - efficient single token processing
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
total_length = current_length + len(generated_tokens) + 1
next_attention = self.torch.ones((1, total_length), device=self.model.device)
# Generate next token using cached KV states
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
next_token_id = self._sample_token(logits, cfg)
generated_tokens.append(next_token_id)
# Yield the new token immediately
token_text = self.tok.decode([next_token_id], skip_special_tokens=True)
if token_text:
yield token_text
# Update session with final state - include all processed tokens
final_input_ids = input_ids[0].tolist() + generated_tokens
full_generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
session.update_cache(past_key_values, final_input_ids, prompt + full_generated_text)
print(f"[KV_CACHE] Streamed {len(generated_tokens)} tokens with persistent KV cache")
def clear_session_cache(self, session_id: str = "default") -> None:
"""Clear KV cache for a specific session"""
session = self.get_session(session_id)
session.invalidate_cache()
print(f"[DEBUG] Cleared cache for session {session_id}")
def get_session_info(self, session_id: str = "default") -> Dict[str, Any]:
"""Get information about a chat session"""
session = self.get_session(session_id)
return session.get_cache_info()
def add_conversation_turn(self, user_message: str, assistant_message: str,
session_id: str = "default") -> None:
"""Add a complete conversation turn to the session history"""
session = self.get_session(session_id)
session.add_message("user", user_message)
session.add_message("assistant", assistant_message)
def get_model_info(self) -> Dict[str, Any]:
"""Get comprehensive information about the loaded model"""
try:
# Basic model info
info = {
"model_name": getattr(self.model.config, 'name_or_path', 'unknown'),
"model_type": self.model.config.model_type,
"vocab_size": self.model.config.vocab_size,
"device": str(self.device),
"dtype": str(self.model.dtype),
"supports_kv_cache": True,
"max_position_embeddings": getattr(self.model.config, 'max_position_embeddings', 'unknown'),
"torch_compile_enabled": hasattr(self.model, '_orig_mod')
}
# Memory info
if self.torch.cuda.is_available() and str(self.device) != 'cpu':
info.update({
"gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / 1024**3:.2f} GB",
"gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / 1024**3:.2f} GB",
"gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
})
# Model parameters
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
info.update({
"total_parameters": f"{total_params:,}",
"trainable_parameters": f"{trainable_params:,}",
"model_size_mb": f"{total_params * 4 / 1024**2:.2f} MB" # Assuming float32
})
return info
except Exception as e:
return {"error": f"Could not get model info: {e}", "supports_kv_cache": True}
def get_kv_cache_stats(self) -> Dict[str, Any]:
"""Get KV cache statistics across all sessions"""
try:
sessions = self.session_manager.get_all_sessions()
stats = {
"total_sessions": len(sessions),
"active_sessions": 0,
"total_cached_tokens": 0,
"memory_usage_estimate": "0 MB"
}
for session_id in sessions:
session = self.session_manager.get_session(session_id)
if session.past_key_values is not None:
stats["active_sessions"] += 1
stats["total_cached_tokens"] += session.context_length
# Rough estimate of KV cache memory usage
# Each token in KV cache roughly uses: hidden_size * num_layers * 2 (key + value) * 4 bytes (float32)
if stats["total_cached_tokens"] > 0:
try:
hidden_size = self.model.config.hidden_size
num_layers = self.model.config.num_hidden_layers
memory_bytes = stats["total_cached_tokens"] * hidden_size * num_layers * 2 * 4
stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB"
except:
pass
return stats
except Exception as e:
return {"error": f"Could not get KV cache stats: {e}"}
def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, float]:
"""Warm up the model and measure performance metrics"""
try:
print("[KV_CACHE] Warming up model...")
start_time = time.time()
# Simple generation to warm up CUDA kernels
cfg = GenerateConfig(max_tokens=5, temperature=0.0)
_ = self.generate(test_prompt, cfg=cfg, session_id="warmup")
warmup_time = time.time() - start_time
# Clean up warmup session
self.clear_session_cache("warmup")
return {
"warmup_time": warmup_time,
"status": "success"
}
except Exception as e:
return {
"warmup_time": 0.0,
"status": "failed",
"error": str(e)
}
class HFTransformersLoader:
name = "hf"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Accept HF repo-id or local dir with config.json (covers .safetensors)
if "/" in source and not os.path.exists(source): # repo-id like "microsoft/DialoGPT-medium"
return True
# Check if it's a directory with config.json
if os.path.isdir(source) and os.path.exists(os.path.join(source, "config.json")):
return True
# If it's a single .safetensors file, look for config.json in the same directory
if source.lower().endswith('.safetensors'):
parent_dir = os.path.dirname(source)
return os.path.exists(os.path.join(parent_dir, "config.json"))
# Also support .bin files (PyTorch checkpoints)
if source.lower().endswith('.bin'):
parent_dir = os.path.dirname(source)
return os.path.exists(os.path.join(parent_dir, "config.json"))
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _HFUnified(source, **kwargs)