first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,137 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _ExLlama2Unified:
def __init__(self, src: str, **kwargs: Any):
try:
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
except ImportError:
raise ImportError("exllamav2 is required for EXL2 models. Install with: pip install exllamav2")
print(f"Loading EXL2 model from: {src}")
# Configure model
self.config = ExLlamaV2Config(src)
# Apply any config overrides
if "max_seq_len" in kwargs:
self.config.max_seq_len = kwargs["max_seq_len"]
if "scale_pos_emb" in kwargs:
self.config.scale_pos_emb = kwargs["scale_pos_emb"]
if "scale_alpha_value" in kwargs:
self.config.scale_alpha_value = kwargs["scale_alpha_value"]
# Initialize model
self.model = ExLlamaV2(self.config)
# Load model weights
self.model.load()
# Initialize tokenizer
self.tokenizer = ExLlamaV2Tokenizer(self.config)
# Initialize cache
self.cache = ExLlamaV2Cache(self.model, lazy=kwargs.get("lazy_cache", True))
# Initialize generator for streaming
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
print(f"EXL2 model loaded successfully")
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
# ExLlamaV2 expects stop strings as a list
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Generate text
output = self.generator.generate_simple(
prompt=prompt,
max_new_tokens=cfg.max_tokens,
seed=kwargs.get("seed", -1),
token_healing=kwargs.get("token_healing", True),
temperature=cfg.temperature,
top_p=cfg.top_p,
stop_conditions=stop_conditions,
)
return output
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
# Create sampler settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = cfg.temperature
settings.top_p = cfg.top_p
# Set stop conditions
if cfg.stop:
stop_conditions = list(cfg.stop)
else:
stop_conditions = []
# Begin streaming generation
input_ids = self.tokenizer.encode(prompt)
self.generator.begin_stream(
input_ids=input_ids,
gen_settings=settings,
token_healing=kwargs.get("token_healing", True),
seed=kwargs.get("seed", -1),
)
generated_tokens = 0
while generated_tokens < cfg.max_tokens:
chunk, eos, tokens = self.generator.stream()
if chunk:
yield chunk
generated_tokens += tokens
# Check for stop conditions
if eos:
break
if cfg.stop:
# Check if any stop condition is met in the generated text so far
current_text = self.generator.sequence_str
if any(stop in current_text for stop in cfg.stop):
break
def tokenize(self, text: str) -> List[int]:
return self.tokenizer.encode(text).tolist()
def detokenize(self, ids: List[int]) -> str:
import torch
tensor_ids = torch.tensor([ids], dtype=torch.long)
return self.tokenizer.decode(tensor_ids)[0]
class ExLlama2Loader:
name = "exllama2"
def can_load(self, source: str, **kwargs: Any) -> bool:
# Check for EXL2 model directory structure
if os.path.isdir(source):
# Look for config.json and .safetensors files that indicate EXL2
config_path = os.path.join(source, "config.json")
if os.path.exists(config_path):
# Check if there are .safetensors files with EXL2 naming pattern
for file in os.listdir(source):
if file.endswith(".safetensors") and ("model" in file.lower() or "exl2" in file.lower()):
return True
elif source.lower().endswith(".exl2"):
return True
return False
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
return _ExLlama2Unified(source, **kwargs)