"""Hugging Face ``transformers`` backend for llm_runtime.

``_HFUnified`` wraps an ``AutoModelForCausalLM``/``AutoTokenizer`` pair and
runs generation either through a manual loop that keeps per-session KV caches
or through ``model.generate``; ``HFTransformersLoader`` decides which sources
this backend accepts.
"""

import contextlib
import copy
import os
import threading
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from llm_runtime.chat_session import ChatSession, get_session_manager
from llm_runtime.device_utils import device_for_hf
from llm_runtime.types import GenerateConfig, UnifiedModel
from llm_runtime.util_chat import apply_chat_template


class _HFUnified:
    """Hugging Face causal LM exposed through the runtime's model interface.

    ``generate``/``stream`` route through a per-session KV cache (session
    ``"default"`` unless the caller passes ``session_id``); passing a falsy
    ``session_id`` falls back to plain ``model.generate``.
    """

    # Values used for GenerateConfig fields that are left as None.
    _DEFAULT_TEMPERATURE = 0.7
    _DEFAULT_TOP_P = 0.95
    _DEFAULT_MAX_TOKENS = 500

    def __init__(self, src: str, **kwargs: Any):
        """Load tokenizer and model from ``src`` (Hub repo id or local directory).

        Recognised kwargs (all optional): ``device``, ``trust_remote_code``
        (default True), ``token``, ``torch_dtype`` (default "auto"),
        ``quantization`` ("none" | "4bit" | "8bit"), ``device_strategy``,
        ``gpu_memory_limit`` (GiB), ``offload_folder`` and ``n_ctx`` (context
        size for chat sessions, default 4096).

        Raises:
            ImportError: if torch or transformers is not installed.
        """
        print(f"[HF_DEBUG] _HFUnified.__init__() called with src='{src}', kwargs={kwargs}")
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
            print("[HF_DEBUG] Successfully imported torch and transformers")
        except ImportError as err:
            print("[HF_DEBUG] Failed to import torch or transformers")
            raise ImportError(
                "transformers, torch, and accelerate are required for HF models. "
                "Install with: pip install transformers torch accelerate safetensors"
            ) from err

        self.torch = torch
        self.TextIteratorStreamer = TextIteratorStreamer

        # Normalize device for HuggingFace
        self.device = device_for_hf(kwargs.get("device"))
        print(f"Loading HF model from: {src}")
        print(f"Using device: {self.device}")

        trust_remote_code = kwargs.get("trust_remote_code", True)
        token = kwargs.get("token")

        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote_code,
            token=token,
        )

        quantization_config = self._prepare_quantization_config(kwargs)
        device_map, max_memory = self._prepare_device_mapping(kwargs, self.device)

        load_kwargs: Dict[str, Any] = {
            "torch_dtype": kwargs.get("torch_dtype", "auto"),
            "device_map": device_map,
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True,
            "token": token,
        }
        if quantization_config:
            load_kwargs["quantization_config"] = quantization_config
            print(f"Using quantization: {kwargs.get('quantization', 'none')}")
            # When quantization is enabled, force device_map to "auto" to avoid
            # device format conflicts.
            load_kwargs["device_map"] = "auto"
            print("[DEBUG] Forcing device_map='auto' for quantization compatibility")
        if max_memory:
            load_kwargs["max_memory"] = max_memory
            print(f"Memory limits: {max_memory}")
        offload_folder = kwargs.get("offload_folder")
        if offload_folder:
            load_kwargs["offload_folder"] = offload_folder
            print(f"Offloading to: {offload_folder}")

        self.model = AutoModelForCausalLM.from_pretrained(src, **load_kwargs)

        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

        self.session_manager = get_session_manager()
        self.current_session = None
        # Context size passed to ChatSessionConfig when sessions are created.
        self.n_ctx = kwargs.get('n_ctx', 4096)

    # ------------------------------------------------------------------
    # Loading helpers
    # ------------------------------------------------------------------

    def _prepare_quantization_config(self, kwargs: Any):
        """Return a ``BitsAndBytesConfig`` for ``kwargs["quantization"]``, or None.

        "4bit" selects NF4 with double quantization and bfloat16 compute;
        "8bit" selects int8 loading. "none", unknown values, or a transformers
        build without ``BitsAndBytesConfig`` yield None (with a warning for the
        latter two).
        """
        quantization = kwargs.get("quantization", "none")
        if quantization == "none":
            return None

        try:
            from transformers import BitsAndBytesConfig
        except ImportError:
            print("Warning: BitsAndBytesConfig not available, skipping quantization")
            return None

        if quantization == "4bit":
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=self.torch.bfloat16,
            )
        if quantization == "8bit":
            return BitsAndBytesConfig(load_in_8bit=True)

        print(f"Warning: Unknown quantization type '{quantization}', skipping")
        return None

    @staticmethod
    def _gpu_index(hf_device: Any) -> int:
        """Return N for a "cuda:N" device string, otherwise 0."""
        device = str(hf_device)
        if device.startswith("cuda:"):
            suffix = device.split(":", 1)[1]
            if suffix.isdigit():
                return int(suffix)
        return 0

    def _prepare_device_mapping(self, kwargs: Any, hf_device: str):
        """Return ``(device_map, max_memory)`` for ``from_pretrained``.

        Strategies (``kwargs["device_strategy"]``):
          * "force_gpu"      -- everything on ``hf_device``.
          * "cpu_only"       -- everything on the CPU.
          * "balanced_split" -- accelerate placement; with ``gpu_memory_limit``
                                the CPU may take up to 48GiB.
          * anything else    -- accelerate placement; with ``gpu_memory_limit``
                                the CPU may take up to 32GiB.

        ``max_memory`` is keyed by integer GPU index plus "cpu", which is the
        key format accelerate accepts for GPUs.
        """
        device_strategy = kwargs.get("device_strategy", "auto")
        gpu_memory_limit = kwargs.get("gpu_memory_limit", None)

        device_map: Any = "auto"
        max_memory = None

        if device_strategy == "force_gpu":
            device_map = {"": hf_device}
        elif device_strategy == "cpu_only":
            device_map = {"": "cpu"}
        elif gpu_memory_limit:
            cpu_limit = "48GiB" if device_strategy == "balanced_split" else "32GiB"
            max_memory = {
                self._gpu_index(hf_device): f"{gpu_memory_limit}GiB",
                "cpu": cpu_limit,
            }

        return device_map, max_memory

    # ------------------------------------------------------------------
    # Config / stopping helpers
    # ------------------------------------------------------------------

    def _resolve_config(self, cfg: Optional[GenerateConfig]) -> GenerateConfig:
        """Return a copy of ``cfg`` (or a fresh GenerateConfig) with None fields defaulted.

        Copying keeps the caller's object unmodified.
        """
        resolved = GenerateConfig() if cfg is None else copy.copy(cfg)
        if resolved.temperature is None:
            resolved.temperature = self._DEFAULT_TEMPERATURE
        if resolved.top_p is None:
            resolved.top_p = self._DEFAULT_TOP_P
        if resolved.max_tokens is None:
            resolved.max_tokens = self._DEFAULT_MAX_TOKENS
        return resolved

    def _default_eos_ids(self) -> List[int]:
        """Return EOS ids from the model's generation config and the tokenizer, de-duplicated."""
        gen_config = getattr(self.model, "generation_config", None)
        gen_eos = getattr(gen_config, "eos_token_id", None)
        candidates = list(gen_eos) if isinstance(gen_eos, (list, tuple)) else [gen_eos]
        candidates.append(self.tok.eos_token_id)
        eos_ids: List[int] = []
        for token_id in candidates:
            if token_id is not None and token_id not in eos_ids:
                eos_ids.append(token_id)
        return eos_ids

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Return ``eos_token_id`` for ``model.generate`` covering single-character stops.

        Only single-character stop strings that map to a real vocabulary token
        (not ``unk``) are converted. Because ``eos_token_id`` replaces the
        model's default, the default EOS ids are kept in the list. Returns
        None when there is nothing to add, so the model default applies.
        """
        if not stop:
            return None

        unk_id = self.tok.unk_token_id
        stop_ids: List[int] = []
        for stop_str in stop:
            if len(stop_str) != 1:
                continue  # multi-character stops are matched on decoded text
            token_id = self.tok.convert_tokens_to_ids(stop_str)
            if token_id is None or token_id == unk_id or token_id in stop_ids:
                continue
            stop_ids.append(token_id)

        if not stop_ids:
            return None
        eos_ids = self._default_eos_ids()
        return eos_ids + [token_id for token_id in stop_ids if token_id not in eos_ids]

    def _make_stop_criteria(self, stops: List[str], prompt_len: int):
        """Build a StoppingCriteriaList that fires once any stop string is generated.

        Only tokens after the first ``prompt_len`` positions are decoded, so a
        stop string that already appears in the prompt does not end generation.
        """
        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok
        torch = self.torch

        class _StopStringCriteria(StoppingCriteria):
            def __call__(self, input_ids, scores, **_):
                texts = [tok.decode(row[prompt_len:], skip_special_tokens=True) for row in input_ids]
                done = [any(s in text for s in stops) for text in texts]
                return torch.tensor(done, dtype=torch.bool, device=input_ids.device)

        return StoppingCriteriaList([_StopStringCriteria()])

    def _fallback_generate_kwargs(self, cfg: GenerateConfig, prompt_len: int) -> Dict[str, Any]:
        """Keyword arguments for ``model.generate`` on the non-cached paths."""
        do_sample = cfg.temperature is not None and cfg.temperature > 0.0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": cfg.max_tokens,
            "do_sample": do_sample,
            "temperature": cfg.temperature if do_sample else None,
            "top_p": cfg.top_p,
            "eos_token_id": self._build_eos_ids(cfg.stop),
            "pad_token_id": self.tok.pad_token_id,
        }
        if cfg.stop:
            gen_kwargs["stopping_criteria"] = self._make_stop_criteria(cfg.stop, prompt_len)
        return gen_kwargs

    # ------------------------------------------------------------------
    # Public generation API
    # ------------------------------------------------------------------

    def generate(self, prompt: str, cfg: Optional[GenerateConfig] = None, **kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Uses the KV-cached path for ``kwargs["session_id"]`` (default
        "default"); a falsy ``session_id`` uses plain ``model.generate``.
        None fields of ``cfg`` are defaulted on a copy (temperature 0.7,
        top_p 0.95, max_tokens 500).
        """
        cfg = self._resolve_config(cfg)

        # pop, not get: session_id is passed positionally below, so leaving it
        # in kwargs would raise "multiple values for argument 'session_id'".
        session_id = kwargs.pop('session_id', 'default')
        if session_id:
            return self.generate_with_cache(prompt, session_id, cfg, **kwargs)

        # Fallback to standard generation without KV cache
        inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs.input_ids.shape[1]
        with self.torch.no_grad():
            out_ids = self.model.generate(**inputs, **self._fallback_generate_kwargs(cfg, prompt_len))
        # Decode only the new tokens
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def stream(self, prompt: str, cfg: Optional[GenerateConfig] = None, **kwargs: Any) -> Iterator[str]:
        """Stream a completion for ``prompt`` as text chunks.

        Uses the KV-cached path for ``kwargs["session_id"]`` (default
        "default"); a falsy ``session_id`` streams from ``model.generate`` in a
        worker thread. Config defaults match ``generate``.
        """
        cfg = self._resolve_config(cfg)

        session_id = kwargs.pop('session_id', 'default')  # Remove from kwargs to avoid duplicate
        if session_id:
            yield from self.stream_with_cache(prompt, session_id, cfg, **kwargs)
            return

        # Fallback to standard streaming without KV cache
        enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
        gen_kwargs = self._fallback_generate_kwargs(cfg, enc.input_ids.shape[1])
        streamer = self.TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        errors: List[BaseException] = []

        def _worker() -> None:
            try:
                with self.torch.no_grad():
                    self.model.generate(**enc, **gen_kwargs, streamer=streamer)
            except BaseException as exc:  # re-raised in the consumer below
                errors.append(exc)
                # generate() did not finish, so signal end-of-stream ourselves;
                # otherwise the consumer would wait on the streamer forever.
                streamer.end()

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        for text in streamer:
            yield text
        worker.join()
        if errors:
            raise errors[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids without adding special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)

    # ------------------------------------------------------------------
    # Sessions and model forward passes
    # ------------------------------------------------------------------

    def get_session(self, session_id: str = "default") -> ChatSession:
        """Get or create the chat session ``session_id`` sized to ``self.n_ctx``."""
        from llm_runtime.chat_session import ChatSessionConfig
        session_config = ChatSessionConfig(max_context_length=self.n_ctx)
        return self.session_manager.get_session(session_id, config=session_config)

    def _prefill_phase(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
                       cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
        """Run the full prompt through the model; return ``(logits, past_key_values)``."""
        with self.torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True,
            )
        return outputs.logits, outputs.past_key_values

    def _prefill_incremental(self, new_input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
                             past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
        """Run only ``new_input_ids`` on top of ``past_key_values``.

        ``attention_mask`` must span cached plus new positions.
        """
        with self.torch.no_grad():
            outputs = self.model(
                input_ids=new_input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
        return outputs.logits, outputs.past_key_values

    def _decode_step(self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor',
                     past_key_values: Any, cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
        """Feed one token (``input_ids`` of shape [1, 1]) through the cached model."""
        with self.torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
        return outputs.logits, outputs.past_key_values

    def _sample_token(self, logits: 'torch.Tensor', cfg: GenerateConfig) -> int:
        """Pick the next token id from ``logits[0, -1, :]``.

        Greedy (argmax) unless ``cfg.temperature`` > 0; otherwise
        temperature-scaled sampling, restricted to the top-p nucleus when
        ``cfg.top_p`` < 1.0.
        """
        # float32 keeps softmax/multinomial well-defined for half-precision models.
        next_token_logits = logits[0, -1, :].float()
        temperature = cfg.temperature
        if temperature is None or not temperature > 0.0:
            return int(self.torch.argmax(next_token_logits, dim=-1).item())

        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        if cfg.top_p is not None and cfg.top_p < 1.0:
            sorted_logits, sorted_indices = self.torch.sort(next_token_logits, descending=True)
            cumulative_probs = self.torch.cumsum(self.torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens past the nucleus; the shift keeps the token that
            # crosses top_p, and the most likely token is always kept.
            sorted_to_remove = cumulative_probs > cfg.top_p
            sorted_to_remove[1:] = sorted_to_remove[:-1].clone()
            sorted_to_remove[0] = False
            to_remove = sorted_to_remove.scatter(0, sorted_indices, sorted_to_remove)
            next_token_logits = next_token_logits.masked_fill(to_remove, float('-inf'))

        probs = self.torch.softmax(next_token_logits, dim=-1)
        return int(self.torch.multinomial(probs, num_samples=1).item())

    def _should_stop(self, token_id: int, generated_text: str, cfg: GenerateConfig) -> bool:
        """Return True when ``token_id`` is an EOS id or a stop string was generated."""
        if token_id in self._default_eos_ids():
            return True
        return bool(cfg.stop) and any(stop_str in generated_text for stop_str in cfg.stop)

    # ------------------------------------------------------------------
    # KV-cached generation
    # ------------------------------------------------------------------

    def _prefill_session(self, session: ChatSession, input_ids: 'torch.Tensor',
                         attention_mask: 'torch.Tensor', cfg: GenerateConfig) -> Tuple['torch.Tensor', Any]:
        """Bring the session's KV cache up to the full prompt; return ``(logits, past_key_values)``.

        The cache is extended incrementally only when its token ids are a
        strict prefix of the prompt's ids. Otherwise (different text,
        re-tokenization drift, or an identical/shorter prompt whose last-position
        logits are not stored) it is dropped and the whole prompt is prefilled.
        """
        prompt_ids = input_ids[0].tolist()
        cached = session.cached_input_ids if session.past_key_values is not None else None
        cached_ids = list(cached) if cached is not None else []
        n_cached = len(cached_ids)

        if 0 < n_cached < len(prompt_ids) and prompt_ids[:n_cached] == cached_ids:
            print(f"[KV_CACHE] Using cached KV state, context length: {session.context_length}")
            print(f"[KV_CACHE] Processing {len(prompt_ids) - n_cached} new tokens (incremental prefill)")
            full_mask = self.torch.ones((1, len(prompt_ids)), dtype=attention_mask.dtype, device=input_ids.device)
            return self._prefill_incremental(input_ids[:, n_cached:], full_mask, session.past_key_values, cfg)

        if session.past_key_values is not None:
            session.invalidate_cache()
            print("[KV_CACHE] Cached tokens do not prefix the prompt; running full prefill")
        print(f"[KV_CACHE] Running prefill phase for {len(prompt_ids)} tokens")
        return self._prefill_phase(input_ids, attention_mask, cfg)

    def _cached_generation(self, prompt: str, session_id: str,
                           cfg: GenerateConfig) -> Iterator[Tuple[int, str]]:
        """Core KV-cached generation loop shared by generate/stream.

        Yields ``(token_id, text_so_far)`` after each sampled token, where
        ``text_so_far`` decodes every token generated in this call (special
        tokens skipped). On normal completion or when the consumer closes the
        generator, the session is updated with exactly the ids whose keys and
        values are in the cache: the prompt plus every generated token except
        the last, which is never fed back. On any other exception the session
        cache is invalidated, since it may have been partially updated.
        """
        session = self.get_session(session_id)
        device = self.model.device
        inputs = self.tok(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        prompt_ids = input_ids[0].tolist()

        if session.should_invalidate(prompt, self.tok):
            session.invalidate_cache()
            print(f"[KV_CACHE] Cache invalidated for session {session_id}")

        max_new_tokens = max(int(cfg.max_tokens), 0)
        fed_ids: List[int] = []  # ids whose keys/values are held in past_key_values
        generated: List[int] = []
        past_key_values: Any = None
        consistent = False
        try:
            logits, past_key_values = self._prefill_session(session, input_ids, attention_mask, cfg)
            fed_ids = list(prompt_ids)
            for step in range(max_new_tokens):
                next_token_id = self._sample_token(logits, cfg)
                generated.append(next_token_id)
                text = self.tok.decode(generated, skip_special_tokens=True)
                yield next_token_id, text
                if step == max_new_tokens - 1 or self._should_stop(next_token_id, text, cfg):
                    break
                step_input = self.torch.tensor([[next_token_id]], device=device)
                # After this step the cache holds len(fed_ids) + 1 positions.
                step_mask = self.torch.ones((1, len(fed_ids) + 1), dtype=attention_mask.dtype, device=device)
                logits, past_key_values = self._decode_step(step_input, step_mask, past_key_values, cfg)
                fed_ids.append(next_token_id)
            consistent = True
        except GeneratorExit:
            # Raised at the yield: the cache matches fed_ids at that point.
            consistent = True
            raise
        finally:
            if consistent:
                fed_generated = generated[: len(fed_ids) - len(prompt_ids)]
                cached_text = prompt + self.tok.decode(fed_generated, skip_special_tokens=True)
                session.update_cache(past_key_values, fed_ids, cached_text)
            else:
                session.invalidate_cache()

    def generate_with_cache(self, prompt: str, session_id: str = "default",
                            cfg: Optional[GenerateConfig] = None, **kwargs: Any) -> str:
        """Generate text with the persistent KV cache of session ``session_id``."""
        cfg = self._resolve_config(cfg)
        text = ""
        n_tokens = 0
        with contextlib.closing(self._cached_generation(prompt, session_id, cfg)) as steps:
            for _, text in steps:
                n_tokens += 1
        print(f"[KV_CACHE] Generated {n_tokens} tokens with KV cache")
        return text

    def stream_with_cache(self, prompt: str, session_id: str = "default",
                          cfg: Optional[GenerateConfig] = None, **kwargs: Any) -> Iterator[str]:
        """Stream text generated with the persistent KV cache of session ``session_id``.

        Chunks are suffixes of the cumulative decoded text, so their
        concatenation equals what ``generate_with_cache`` would return.
        """
        cfg = self._resolve_config(cfg)
        text = ""
        emitted = 0
        n_tokens = 0
        with contextlib.closing(self._cached_generation(prompt, session_id, cfg)) as steps:
            for _, text in steps:
                n_tokens += 1
                # A trailing U+FFFD means a multi-byte character is still split
                # across tokens; wait for the following token(s).
                if text.endswith("\ufffd"):
                    continue
                if len(text) > emitted:
                    yield text[emitted:]
                    emitted = len(text)
        if len(text) > emitted:
            yield text[emitted:]
        print(f"[KV_CACHE] Streamed {n_tokens} tokens with persistent KV cache")

    def clear_session_cache(self, session_id: str = "default") -> None:
        """Invalidate the KV cache of session ``session_id``."""
        session = self.get_session(session_id)
        session.invalidate_cache()
        print(f"[DEBUG] Cleared cache for session {session_id}")

    def get_session_info(self, session_id: str = "default") -> Dict[str, Any]:
        """Return ``get_cache_info()`` of session ``session_id``."""
        session = self.get_session(session_id)
        return session.get_cache_info()

    def add_conversation_turn(self, user_message: str, assistant_message: str,
                              session_id: str = "default") -> None:
        """Append a user message and an assistant reply to the session history."""
        session = self.get_session(session_id)
        session.add_message("user", user_message)
        session.add_message("assistant", assistant_message)

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------

    def get_model_info(self) -> Dict[str, Any]:
        """Return a dict describing the loaded model, parameters and GPU memory.

        Never raises: on failure returns ``{"error": ..., "supports_kv_cache": True}``.
        """
        try:
            info = {
                "model_name": getattr(self.model.config, 'name_or_path', 'unknown'),
                "model_type": self.model.config.model_type,
                "vocab_size": self.model.config.vocab_size,
                "device": str(self.device),
                "dtype": str(self.model.dtype),
                "supports_kv_cache": True,
                "max_position_embeddings": getattr(self.model.config, 'max_position_embeddings', 'unknown'),
                "torch_compile_enabled": hasattr(self.model, '_orig_mod'),
            }

            if self.torch.cuda.is_available() and str(self.device) != 'cpu':
                info.update({
                    "gpu_memory_allocated": f"{self.torch.cuda.memory_allocated() / 1024**3:.2f} GB",
                    "gpu_memory_reserved": f"{self.torch.cuda.memory_reserved() / 1024**3:.2f} GB",
                    "gpu_memory_total": f"{self.torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB",
                })

            total_params = 0
            trainable_params = 0
            total_bytes = 0
            for param in self.model.parameters():
                count = param.numel()
                total_params += count
                total_bytes += count * param.element_size()
                if param.requires_grad:
                    trainable_params += count
            info.update({
                "total_parameters": f"{total_params:,}",
                "trainable_parameters": f"{trainable_params:,}",
                # Storage of the parameter tensors in their actual dtypes.
                "model_size_mb": f"{total_bytes / 1024**2:.2f} MB",
            })
            return info
        except Exception as e:
            return {"error": f"Could not get model info: {e}", "supports_kv_cache": True}

    def _kv_bytes_per_token(self) -> Optional[int]:
        """Rough KV-cache bytes per token, or None if the config lacks the fields.

        layers x 2 (key + value) x KV width x element size, where the KV width
        accounts for grouped-query attention when the config exposes
        ``num_key_value_heads``.
        """
        try:
            config = self.model.config
            hidden_size = config.hidden_size
            num_layers = config.num_hidden_layers
            num_heads = getattr(config, "num_attention_heads", None)
            num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
            head_dim = getattr(config, "head_dim", None) or (hidden_size // num_heads if num_heads else None)
            kv_width = num_kv_heads * head_dim if num_kv_heads and head_dim else hidden_size
        except (AttributeError, TypeError, ZeroDivisionError):
            return None
        try:
            element_size = self.torch.empty((), dtype=self.model.dtype).element_size()
        except (AttributeError, TypeError, RuntimeError):
            element_size = 4
        return num_layers * 2 * kv_width * element_size

    def get_kv_cache_stats(self) -> Dict[str, Any]:
        """Return session counts, cached-token totals and a rough KV memory estimate."""
        try:
            sessions = self.session_manager.get_all_sessions()
            stats = {
                "total_sessions": len(sessions),
                "active_sessions": 0,
                "total_cached_tokens": 0,
                "memory_usage_estimate": "0 MB",
            }

            for session_id in sessions:
                session = self.session_manager.get_session(session_id)
                if session.past_key_values is not None:
                    stats["active_sessions"] += 1
                    stats["total_cached_tokens"] += session.context_length

            if stats["total_cached_tokens"] > 0:
                bytes_per_token = self._kv_bytes_per_token()
                if bytes_per_token is not None:
                    memory_bytes = stats["total_cached_tokens"] * bytes_per_token
                    stats["memory_usage_estimate"] = f"{memory_bytes / 1024**2:.2f} MB"

            return stats
        except Exception as e:
            return {"error": f"Could not get KV cache stats: {e}"}

    def warm_up_model(self, test_prompt: str = "Hello") -> Dict[str, float]:
        """Run a short greedy generation on a throwaway session and time it."""
        try:
            print("[KV_CACHE] Warming up model...")
            start_time = time.perf_counter()

            cfg = GenerateConfig(max_tokens=5, temperature=0.0)
            _ = self.generate(test_prompt, cfg=cfg, session_id="warmup")

            warmup_time = time.perf_counter() - start_time

            self.clear_session_cache("warmup")

            return {
                "warmup_time": warmup_time,
                "status": "success",
            }
        except Exception as e:
            return {
                "warmup_time": 0.0,
                "status": "failed",
                "error": str(e),
            }


class HFTransformersLoader:
    """Loader for Hub repo ids and local transformers checkpoints."""

    name = "hf"

    # Single-file checkpoints that are loaded from their containing directory.
    _WEIGHT_SUFFIXES = ('.safetensors', '.bin')

    @staticmethod
    def _looks_like_repo_id(source: str) -> bool:
        """Return True for strings shaped like a Hub repo id ("namespace/name")."""
        if source.startswith(("/", ".", "~")) or "\\" in source or ":" in source:
            return False
        parts = source.split("/")
        return len(parts) == 2 and all(parts)

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Accept a Hub repo id, a directory with config.json, or a weights file next to one."""
        # repo-id like "microsoft/DialoGPT-medium"
        if not os.path.exists(source) and self._looks_like_repo_id(source):
            return True

        if os.path.isdir(source) and os.path.exists(os.path.join(source, "config.json")):
            return True

        # A single .safetensors / .bin checkpoint needs config.json beside it.
        if source.lower().endswith(self._WEIGHT_SUFFIXES):
            parent_dir = os.path.dirname(source)
            return os.path.exists(os.path.join(parent_dir, "config.json"))

        return False

    def load(self, source: str, **kwargs: Any) -> UnifiedModel:
        """Instantiate ``_HFUnified`` for ``source``.

        ``from_pretrained`` expects a directory or repo id, so a single weights
        file accepted by ``can_load`` is loaded from its directory.
        """
        if source.lower().endswith(self._WEIGHT_SUFFIXES) and not os.path.isdir(source):
            parent_dir = os.path.dirname(source) or "."
            if os.path.exists(os.path.join(parent_dir, "config.json")):
                source = parent_dir
        return _HFUnified(source, **kwargs)