first commit
This commit is contained in:
672
llm_runtime/loaders/transformers_loader.py
Normal file
672
llm_runtime/loaders/transformers_loader.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from typing import Any, Iterator, List, Optional, Tuple, Dict
|
||||
import os
|
||||
import time
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
from llm_runtime.util_chat import apply_chat_template
|
||||
from llm_runtime.chat_session import ChatSession, get_session_manager
|
||||
from llm_runtime.device_utils import device_for_hf
|
||||
|
||||
class _HFUnified:
|
||||
def __init__(self, src: str, **kwargs: Any):
|
||||
print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={kwargs}")
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
print("[HF_DEBUG] Successfully imported torch and transformers")
|
||||
except ImportError:
|
||||
print("[HF_DEBUG] Failed to import torch or transformers")
|
||||
raise ImportError("transformers, torch, and accelerate are required for HF models. Install with: pip install transformers torch accelerate safetensors")
|
||||
|
||||
self.torch = torch
|
||||
self.TextIteratorStreamer = TextIteratorStreamer
|
||||
|
||||
# Normalize device for HuggingFace
|
||||
self.device = device_for_hf(kwargs.get("device"))
|
||||
print(f"Loading HF model from: {src}")
|
||||
print(f"Using device: {self.device}")
|
||||
|
||||
# Load tokenizer
|
||||
self.tok = AutoTokenizer.from_pretrained(
|
||||
src,
|
||||
use_fast=True,
|
||||
trust_remote_code=kwargs.get("trust_remote_code", True),
|
||||
token=kwargs.get("token")
|
||||
)
|
||||
|
||||
# Prepare quantization config if requested
|
||||
quantization_config = self._prepare_quantization_config(kwargs)
|
||||
|
||||
# Prepare device mapping
|
||||
device_map, max_memory = self._prepare_device_mapping(kwargs, self.device)
|
||||
|
||||
# Load model with advanced options
|
||||
load_kwargs = {
|
||||
"torch_dtype": kwargs.get("torch_dtype", "auto"),
|
||||
"device_map": device_map,
|
||||
"trust_remote_code": kwargs.get("trust_remote_code", True),
|
||||
"low_cpu_mem_usage": True,
|
||||
"token": kwargs.get("token")
|
||||
}
|
||||
|
||||
if quantization_config:
|
||||
load_kwargs["quantization_config"] = quantization_config
|
||||
print(f"Using quantization: {kwargs.get('quantization', 'none')}")
|
||||
# When quantization is enabled, force device_map to "auto" to avoid device format conflicts
|
||||
load_kwargs["device_map"] = "auto"
|
||||
print(f"[DEBUG] Forcing device_map='auto' for quantization compatibility")
|
||||
|
||||
if max_memory:
|
||||
load_kwargs["max_memory"] = max_memory
|
||||
print(f"Memory limits: {max_memory}")
|
||||
|
||||
if kwargs.get("offload_folder"):
|
||||
load_kwargs["offload_folder"] = kwargs.get("offload_folder")
|
||||
print(f"Offloading to: {kwargs.get('offload_folder')}")
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs)
|
||||
|
||||
# Set pad token if not present
|
||||
if self.tok.pad_token is None:
|
||||
self.tok.pad_token = self.tok.eos_token
|
||||
|
||||
# Initialize chat session support with context from kwargs
|
||||
self.session_manager = get_session_manager()
|
||||
self.current_session = None
|
||||
|
||||
# Store context size for session creation
|
||||
self.n_ctx = kwargs.get('n_ctx', 4096)
|
||||
|
||||
def _prepare_quantization_config(self, kwargs: Any):
|
||||
"""Prepare quantization configuration"""
|
||||
quantization = kwargs.get("quantization", "none")
|
||||
|
||||
if quantization == "none":
|
||||
return None
|
||||
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
except ImportError:
|
||||
print("Warning: BitsAndBytesConfig not available, skipping quantization")
|
||||
return None
|
||||
|
||||
if quantization == "4bit":
|
||||
return BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=self.torch.bfloat16,
|
||||
)
|
||||
elif quantization == "8bit":
|
||||
return BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
print(f"Warning: Unknown quantization type '{quantization}', skipping")
|
||||
return None
|
||||
|
||||
def _prepare_device_mapping(self, kwargs: Any, hf_device: str):
|
||||
"""Prepare device mapping and memory limits"""
|
||||
device_strategy = kwargs.get("device_strategy", "auto")
|
||||
gpu_memory_limit = kwargs.get("gpu_memory_limit", None)
|
||||
|
||||
device_map = "auto"
|
||||
max_memory = None
|
||||
|
||||
if device_strategy == "force_gpu":
|
||||
device_map = {"": hf_device}
|
||||
elif device_strategy == "balanced_split":
|
||||
# Use memory limits for balanced CPU/GPU split
|
||||
if gpu_memory_limit:
|
||||
# For quantization, use integer device index; for non-quantization, use string
|
||||
if kwargs.get("quantization", "none") != "none":
|
||||
gpu_device = 0 if hf_device.startswith("cuda:") else 0
|
||||
else:
|
||||
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
|
||||
max_memory = {
|
||||
gpu_device: f"{gpu_memory_limit}GiB",
|
||||
"cpu": "48GiB", # Large CPU limit
|
||||
}
|
||||
device_map = "auto"
|
||||
elif device_strategy == "cpu_only":
|
||||
device_map = {"": "cpu"}
|
||||
else: # auto
|
||||
if gpu_memory_limit:
|
||||
# For quantization, use integer device index; for non-quantization, use string
|
||||
if kwargs.get("quantization", "none") != "none":
|
||||
gpu_device = 0 if hf_device.startswith("cuda:") else 0
|
||||
else:
|
||||
gpu_device = hf_device if hf_device.startswith("cuda:") else "cuda:0"
|
||||
max_memory = {
|
||||
gpu_device: f"{gpu_memory_limit}GiB",
|
||||
"cpu": "32GiB",
|
||||
}
|
||||
device_map = "auto"
|
||||
|
||||
return device_map, max_memory
|
||||
|
||||
def _build_eos_ids(self, stop) -> Optional[List[int]]:
|
||||
"""Convert stop strings to token IDs"""
|
||||
if not stop:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
for s in stop:
|
||||
# Handle single character stops
|
||||
if len(s) == 1:
|
||||
tid = self.tok.convert_tokens_to_ids(s)
|
||||
if tid is not None:
|
||||
ids.append(tid)
|
||||
return ids or None
|
||||
|
||||
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
|
||||
"""Generate text - uses KV cache if session_id provided, otherwise fallback to standard generation"""
|
||||
# Set defaults for None values
|
||||
if cfg.temperature is None:
|
||||
cfg.temperature = 0.7 # Default temperature
|
||||
if cfg.top_p is None:
|
||||
cfg.top_p = 0.95 # Default top_p
|
||||
if cfg.max_tokens is None:
|
||||
cfg.max_tokens = 500 # Default max tokens
|
||||
|
||||
session_id = kwargs.get('session_id', 'default') # Use 'default' session if none specified
|
||||
if session_id:
|
||||
return self.generate_with_cache(prompt, session_id, cfg, **kwargs)
|
||||
|
||||
# Fallback to standard generation without KV cache
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
|
||||
|
||||
class MultiStringStop(StoppingCriteria):
|
||||
def __init__(self, toks, stops):
|
||||
self.toks, self.stops = toks, stops or []
|
||||
|
||||
def __call__(self, input_ids, scores, **_):
|
||||
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
|
||||
return any(s in text for s in self.stops)
|
||||
|
||||
with self.torch.no_grad():
|
||||
out_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=cfg.max_tokens,
|
||||
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
|
||||
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
|
||||
top_p=cfg.top_p,
|
||||
eos_token_id=self._build_eos_ids(cfg.stop),
|
||||
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
|
||||
pad_token_id=self.tok.pad_token_id,
|
||||
)
|
||||
|
||||
# Decode only the new tokens
|
||||
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
return generated_text
|
||||
|
||||
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
|
||||
"""Stream text generation - uses KV cache if session_id provided, otherwise fallback to standard streaming"""
|
||||
session_id = kwargs.pop('session_id', 'default') # Remove from kwargs to avoid duplicate
|
||||
if session_id:
|
||||
yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs)
|
||||
return
|
||||
|
||||
# Fallback to standard streaming without KV cache
|
||||
import threading
|
||||
|
||||
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
|
||||
streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
def _worker():
|
||||
with self.torch.no_grad():
|
||||
self.model.generate(
|
||||
**enc,
|
||||
max_new_tokens=cfg.max_tokens,
|
||||
do_sample=cfg.temperature is not None and cfg.temperature > 0.0,
|
||||
temperature=cfg.temperature if cfg.temperature is not None and cfg.temperature > 0.0 else None,
|
||||
top_p=cfg.top_p,
|
||||
eos_token_id=self._build_eos_ids(cfg.stop),
|
||||
streamer=streamer,
|
||||
pad_token_id=self.tok.pad_token_id,
|
||||
)
|
||||
|
||||
t = threading.Thread(target=_worker)
|
||||
t.start()
|
||||
|
||||
for text in streamer:
|
||||
yield text
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
return self.tok.encode(text, add_special_tokens=False)
|
||||
|
||||
def detokenize(self, ids: List[int]) -> str:
|
||||
return self.tok.decode(ids, skip_special_tokens=True)
|
||||
|
||||
def get_session(self, session_id: str = "default") -> ChatSession:
|
||||
"""Get or create a chat session for KV caching"""
|
||||
# Create session config with the correct context size
|
||||
from llm_runtime.chat_session import ChatSessionConfig
|
||||
session_config = ChatSessionConfig(max_context_length=self.n_ctx)
|
||||
return self.session_manager.get_session(session_id, config=session_config)
|
||||
|
||||
def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Prefill phase: process the full prompt and return logits + past_key_values"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Incremental prefill: process only new tokens with existing KV cache"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=new_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
|
||||
past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
|
||||
"""Single decode step: generate next token with KV cache"""
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(
|
||||
input_ids=input_ids, # Should be shape [1, 1] for single token
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int:
|
||||
"""Sample next token from logits based on generation config with optimized sampling"""
|
||||
# Get logits for the last token
|
||||
next_token_logits = logits[0, -1, :]
|
||||
|
||||
# Apply temperature scaling
|
||||
if cfg.temperature is not None and cfg.temperature != 1.0 and cfg.temperature > 0:
|
||||
next_token_logits = next_token_logits / cfg.temperature
|
||||
|
||||
# Apply top-p (nucleus) sampling if specified
|
||||
if cfg.temperature is not None and cfg.temperature > 0.0 and cfg.top_p is not None and cfg.top_p < 1.0:
|
||||
sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True)
|
||||
cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > cfg.top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
|
||||
next_token_logits[indices_to_remove] = float('-inf')
|
||||
|
||||
if cfg.temperature is not None and cfg.temperature > 0.0:
|
||||
# Sample from distribution
|
||||
probs = self.torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = self.torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
# Greedy sampling (deterministic)
|
||||
next_token = self.torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
||||
|
||||
return next_token.item()
|
||||
|
||||
def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool:
|
||||
"""Check if generation should stop"""
|
||||
# Check for EOS token
|
||||
if token_id == self.tok.eos_token_id:
|
||||
return True
|
||||
|
||||
# Check for custom stop strings
|
||||
if cfg.stop:
|
||||
for stop_str in cfg.stop:
|
||||
if stop_str in generated_text:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def generate_with_cache(self, prompt: str, session_id: str = "default",
|
||||
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
|
||||
"""Generate text with persistent KV cache using manual generation loop"""
|
||||
# Set defaults for None values
|
||||
if cfg.temperature is None:
|
||||
cfg.temperature = 0.7 # Default temperature
|
||||
if cfg.top_p is None:
|
||||
cfg.top_p = 0.95 # Default top_p
|
||||
if cfg.max_tokens is None:
|
||||
cfg.max_tokens = 500 # Default max tokens
|
||||
|
||||
session = self.get_session(session_id)
|
||||
|
||||
# Tokenize input
|
||||
inputs = self.tok(prompt, return_tensors="pt", padding=True)
|
||||
input_ids = inputs.input_ids.to(self.model.device)
|
||||
attention_mask = inputs.attention_mask.to(self.model.device)
|
||||
|
||||
generated_tokens = []
|
||||
max_new_tokens = cfg.max_tokens
|
||||
|
||||
# Check if we need to invalidate cache with proper token validation
|
||||
if session.should_invalidate(prompt, self.tok):
|
||||
session.invalidate_cache()
|
||||
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
|
||||
|
||||
# Determine if we need prefill phase
|
||||
if session.past_key_values is None:
|
||||
print(f"[KV_CACHE] Running prefill phase for {len(input_ids[0])} tokens")
|
||||
# Prefill phase: process full prompt
|
||||
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
|
||||
|
||||
# Update session cache
|
||||
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
|
||||
|
||||
# Sample first token
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
|
||||
# Update input_ids and attention_mask for decode phase
|
||||
current_length = input_ids.shape[1]
|
||||
else:
|
||||
print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}")
|
||||
# Use cached state
|
||||
past_key_values = session.past_key_values
|
||||
current_length = session.context_length
|
||||
|
||||
# For cached state, we still need to process the new part of the prompt if it exists
|
||||
cached_length = len(session.cached_input_ids)
|
||||
if input_ids.shape[1] > cached_length:
|
||||
print(f"[KV_CACHE] Processing {input_ids.shape[1] - cached_length} new tokens (incremental prefill)")
|
||||
new_tokens = input_ids[:, cached_length:]
|
||||
|
||||
# Process only the new tokens with existing KV cache
|
||||
# Create attention mask that covers both cached and new tokens
|
||||
total_length = cached_length + new_tokens.shape[1]
|
||||
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
|
||||
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
|
||||
current_length = input_ids.shape[1]
|
||||
else:
|
||||
# No new tokens to process - get initial logits for generation
|
||||
# Use a forward pass with the last token to get proper logits distribution
|
||||
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
|
||||
total_length = current_length + 1
|
||||
next_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
|
||||
|
||||
# Sample first new token
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
|
||||
# Decode phase: generate tokens one by one with optimized attention management
|
||||
for step in range(max_new_tokens - 1): # -1 because we already generated first token
|
||||
# Check stop conditions
|
||||
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
|
||||
if self._should_stop(next_token_id, generated_text, cfg):
|
||||
break
|
||||
|
||||
# Prepare inputs for next step - only pass the new token
|
||||
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
|
||||
total_length = current_length + len(generated_tokens) + 1
|
||||
next_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
# Generate next token using cached KV states
|
||||
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
|
||||
# Update session with final state - include all processed tokens
|
||||
final_input_ids = input_ids[0].tolist() + generated_tokens
|
||||
session.update_cache(past_key_values, final_input_ids, prompt + self.tok.decode(generated_tokens, skip_special_tokens=True))
|
||||
|
||||
# Decode generated tokens
|
||||
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
|
||||
print(f"[KV_CACHE] Generated {len(generated_tokens)} tokens with KV cache")
|
||||
|
||||
return generated_text
|
||||
|
||||
def stream_with_cache(self, prompt: str, session_id: str = "default",
|
||||
cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
|
||||
"""Stream generation with persistent KV cache"""
|
||||
session = self.get_session(session_id)
|
||||
|
||||
# Tokenize input
|
||||
inputs = self.tok(prompt, return_tensors="pt", padding=True)
|
||||
input_ids = inputs.input_ids.to(self.model.device)
|
||||
attention_mask = inputs.attention_mask.to(self.model.device)
|
||||
|
||||
generated_tokens = []
|
||||
max_new_tokens = cfg.max_tokens
|
||||
|
||||
# Check if we need to invalidate cache with proper token validation
|
||||
if session.should_invalidate(prompt, self.tok):
|
||||
session.invalidate_cache()
|
||||
print(f"[KV_CACHE] Cache invalidated for session {session_id}")
|
||||
|
||||
# Determine if we need prefill phase
|
||||
if session.past_key_values is None:
|
||||
print(f"[KV_CACHE] Streaming prefill phase for {len(input_ids[0])} tokens")
|
||||
# Prefill phase: process full prompt
|
||||
logits, past_key_values = self._prefill_phase(input_ids, attention_mask, cfg)
|
||||
|
||||
# Update session cache
|
||||
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
|
||||
|
||||
# Sample first token
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
current_length = input_ids.shape[1]
|
||||
|
||||
# Yield first token
|
||||
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
|
||||
if first_text:
|
||||
yield first_text
|
||||
else:
|
||||
print(f"[KV_CACHE] Streaming with cached KV state, context length: {session.context_length}")
|
||||
# Use cached state - optimized logic similar to generate_with_cache
|
||||
past_key_values = session.past_key_values
|
||||
current_length = session.context_length
|
||||
|
||||
# Handle new tokens in prompt with incremental prefill
|
||||
cached_length = len(session.cached_input_ids)
|
||||
if input_ids.shape[1] > cached_length:
|
||||
print(f"[KV_CACHE] Streaming incremental prefill for {input_ids.shape[1] - cached_length} new tokens")
|
||||
new_tokens = input_ids[:, cached_length:]
|
||||
|
||||
# Process only new tokens with existing KV cache
|
||||
total_length = cached_length + new_tokens.shape[1]
|
||||
extended_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
logits, past_key_values = self._prefill_incremental(new_tokens, extended_attention, past_key_values, cfg)
|
||||
session.update_cache(past_key_values, input_ids[0].tolist(), prompt)
|
||||
current_length = input_ids.shape[1]
|
||||
else:
|
||||
# No new tokens - use last cached token for initial generation
|
||||
last_token = self.torch.tensor([[session.cached_input_ids[-1]]], device=self.model.device)
|
||||
total_length = current_length + 1
|
||||
next_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
logits, past_key_values = self._decode_step(last_token, next_attention, past_key_values, cfg)
|
||||
|
||||
# Sample first token
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
|
||||
# Yield first token
|
||||
first_text = self.tok.decode([next_token_id], skip_special_tokens=True)
|
||||
if first_text:
|
||||
yield first_text
|
||||
|
||||
# Decode phase: stream tokens one by one with optimized KV cache usage
|
||||
for step in range(max_new_tokens - 1):
|
||||
# Check stop conditions
|
||||
generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
|
||||
if self._should_stop(next_token_id, generated_text, cfg):
|
||||
break
|
||||
|
||||
# Prepare inputs for next step - efficient single token processing
|
||||
next_input = self.torch.tensor([[next_token_id]], device=self.model.device)
|
||||
total_length = current_length + len(generated_tokens) + 1
|
||||
next_attention = self.torch.ones((1, total_length), device=self.model.device)
|
||||
|
||||
# Generate next token using cached KV states
|
||||
logits, past_key_values = self._decode_step(next_input, next_attention, past_key_values, cfg)
|
||||
next_token_id = self._sample_token(logits, cfg)
|
||||
generated_tokens.append(next_token_id)
|
||||
|
||||
# Yield the new token immediately
|
||||
token_text = self.tok.decode([next_token_id], skip_special_tokens=True)
|
||||
if token_text:
|
||||
yield token_text
|
||||
|
||||
# Update session with final state - include all processed tokens
|
||||
final_input_ids = input_ids[0].tolist() + generated_tokens
|
||||
full_generated_text = self.tok.decode(generated_tokens, skip_special_tokens=True)
|
||||
session.update_cache(past_key_values, final_input_ids, prompt + full_generated_text)
|
||||
|
||||
print(f"[KV_CACHE] Streamed {len(generated_tokens)} tokens with persistent KV cache")
|
||||
|
||||
def clear_session_cache(self, session_id: str = "default") -> None:
|
||||
"""Clear KV cache for a specific session"""
|
||||
session = self.get_session(session_id)
|
||||
session.invalidate_cache()
|
||||
print(f"[DEBUG] Cleared cache for session {session_id}")
|
||||
|
||||
def get_session_info(self, session_id: str = "default") -> Dict[str, Any]:
|
||||
"""Get information about a chat session"""
|
||||
session = self.get_session(session_id)
|
||||
return session.get_cache_info()
|
||||
|
||||
def add_conversation_turn(self, user_message: str, assistant_message: str,
|
||||
session_id: str = "default") -> None:
|
||||
"""Add a complete conversation turn to the session history"""
|
||||
session = self.get_session(session_id)
|
||||
session.add_message("user", user_message)
|
||||
session.add_message("assistant", assistant_message)
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive information about the loaded model"""
|
||||
try:
|
||||
# Basic model info
|
||||
info = {
|
||||
"model_name": getattr(self.model.config, 'name_or_path', 'unknown'),
|
||||
"model_type": self.model.config.model_type,
|
||||
"vocab_size": self.model.config.vocab_size,
|
||||
"device": str(self.device),
|
||||
"dtype": str(self.model.dtype),
|
||||
"supports_kv_cache": True,
|
||||
"max_position_embeddings": getattr(self.model.config, 'max_position_embeddings', 'unknown'),
|
||||
"torch_compile_enabled": hasattr(self.model, '_orig_mod')
|
||||
}
|
||||
|
||||
# Memory info
|
||||
if self.torch.cuda.is_available() and str(self.device) != 'cpu':
|
||||
info.update({
|
||||
"gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / 1024**3:.2f} GB",
|
||||
"gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / 1024**3:.2f} GB",
|
||||
"gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
|
||||
})
|
||||
|
||||
# Model parameters
|
||||
total_params = sum(p.numel() for p in self.model.parameters())
|
||||
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
info.update({
|
||||
"total_parameters": f"{total_params:,}",
|
||||
"trainable_parameters": f"{trainable_params:,}",
|
||||
"model_size_mb": f"{total_params * 4 / 1024**2:.2f} MB" # Assuming float32
|
||||
})
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
return {"error": f"Could not get model info: {e}", "supports_kv_cache": True}
|
||||
|
||||
def get_kv_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get KV cache statistics across all sessions"""
|
||||
try:
|
||||
sessions = self.session_manager.get_all_sessions()
|
||||
stats = {
|
||||
"total_sessions": len(sessions),
|
||||
"active_sessions": 0,
|
||||
"total_cached_tokens": 0,
|
||||
"memory_usage_estimate": "0 MB"
|
||||
}
|
||||
|
||||
for session_id in sessions:
|
||||
session = self.session_manager.get_session(session_id)
|
||||
if session.past_key_values is not None:
|
||||
stats["active_sessions"] += 1
|
||||
stats["total_cached_tokens"] += session.context_length
|
||||
|
||||
# Rough estimate of KV cache memory usage
|
||||
# Each token in KV cache roughly uses: hidden_size * num_layers * 2 (key + value) * 4 bytes (float32)
|
||||
if stats["total_cached_tokens"] > 0:
|
||||
try:
|
||||
hidden_size = self.model.config.hidden_size
|
||||
num_layers = self.model.config.num_hidden_layers
|
||||
memory_bytes = stats["total_cached_tokens"] * hidden_size * num_layers * 2 * 4
|
||||
stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB"
|
||||
except:
|
||||
pass
|
||||
|
||||
return stats
|
||||
except Exception as e:
|
||||
return {"error": f"Could not get KV cache stats: {e}"}
|
||||
|
||||
def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, float]:
|
||||
"""Warm up the model and measure performance metrics"""
|
||||
try:
|
||||
print("[KV_CACHE] Warming up model...")
|
||||
start_time = time.time()
|
||||
|
||||
# Simple generation to warm up CUDA kernels
|
||||
cfg = GenerateConfig(max_tokens=5, temperature=0.0)
|
||||
_ = self.generate(test_prompt, cfg=cfg, session_id="warmup")
|
||||
|
||||
warmup_time = time.time() - start_time
|
||||
|
||||
# Clean up warmup session
|
||||
self.clear_session_cache("warmup")
|
||||
|
||||
return {
|
||||
"warmup_time": warmup_time,
|
||||
"status": "success"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"warmup_time": 0.0,
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
class HFTransformersLoader:
|
||||
name = "hf"
|
||||
|
||||
def can_load(self, source: str, **kwargs: Any) -> bool:
|
||||
# Accept HF repo-id or local dir with config.json (covers .safetensors)
|
||||
if "/" in source and not os.path.exists(source): # repo-id like "microsoft/DialoGPT-medium"
|
||||
return True
|
||||
|
||||
# Check if it's a directory with config.json
|
||||
if os.path.isdir(source) and os.path.exists(os.path.join(source, "config.json")):
|
||||
return True
|
||||
|
||||
# If it's a single .safetensors file, look for config.json in the same directory
|
||||
if source.lower().endswith('.safetensors'):
|
||||
parent_dir = os.path.dirname(source)
|
||||
return os.path.exists(os.path.join(parent_dir, "config.json"))
|
||||
|
||||
# Also support .bin files (PyTorch checkpoints)
|
||||
if source.lower().endswith('.bin'):
|
||||
parent_dir = os.path.dirname(source)
|
||||
return os.path.exists(os.path.join(parent_dir, "config.json"))
|
||||
|
||||
return False
|
||||
|
||||
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
|
||||
return _HFUnified(source, **kwargs)
|
||||
Reference in New Issue
Block a user