first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,132 @@
from typing import Any, Iterator, List, Optional
import os
from llm_runtime.types import UnifiedModel, GenerateConfig
class _AWQUnified:
    """Adapter around an AWQ-quantized causal LM and its tokenizer.

    Exposes ``generate``, ``stream``, ``tokenize`` and ``detokenize`` on top of
    an ``AutoAWQForCausalLM`` model and a ``transformers`` tokenizer.  The
    generation config object is read for ``max_tokens``, ``temperature``,
    ``top_p`` and ``stop``.
    """

    def __init__(self, src: str, **kwargs: Any):
        """Load the tokenizer and the quantized model from ``src``.

        Args:
            src: Local directory or Hugging Face repo id of the AWQ checkpoint.
            **kwargs: Optional ``trust_remote_code`` (default ``True``),
                ``device_map`` (default ``"auto"``) and ``safetensors``
                (default ``True``).

        Raises:
            ImportError: If AutoAWQ or transformers is not installed.
        """
        try:
            try:
                # The PyPI distribution is called "autoawq", but the package it
                # installs is imported as "awq".
                from awq import AutoAWQForCausalLM
            except ImportError:
                # Keep the historical module name working where it exists.
                from autoawq import AutoAWQForCausalLM
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers") from exc

        print(f"Loading AWQ model from: {src}")

        # NOTE(review): trust_remote_code defaults to True, which lets the
        # checkpoint execute its own Python code. Kept for backward
        # compatibility; pass trust_remote_code=False for untrusted sources.
        trust_remote_code = kwargs.get("trust_remote_code", True)

        # Load tokenizer
        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote_code,
        )

        # Load AWQ model
        self.model = AutoAWQForCausalLM.from_quantized(
            src,
            device_map=kwargs.get("device_map", "auto"),
            trust_remote_code=trust_remote_code,
            safetensors=kwargs.get("safetensors", True),
        )

        # Set pad token if not present
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    @staticmethod
    def _stop_list(stop) -> List[str]:
        """Normalize a stop spec (None, one string, or an iterable) to a list."""
        if not stop:
            return []
        if isinstance(stop, str):
            # A bare string is one stop sequence, not a sequence of characters.
            return [stop]
        # An empty stop string would match any text and end generation at once.
        return [s for s in stop if s]

    def _device(self):
        """Return the device the prompt tensors must be moved to."""
        device = getattr(self.model, "device", None)
        if device is None:
            # NOTE(review): fallback for model wrappers that do not expose a
            # `.device` attribute; assumes a torch-style `parameters()`.
            device = next(self.model.parameters()).device
        return device

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Convert single-character stop strings to end-of-sequence token IDs.

        Only stop strings of length 1 are looked up as vocabulary tokens;
        longer stops are handled by the text-based stopping criterion.

        Returns:
            ``None`` when no stop string maps to a token.  Otherwise the
            tokenizer's own EOS id (if any) followed by the stop token ids,
            de-duplicated, so that the regular EOS still ends generation.
        """
        unk_id = getattr(self.tok, "unk_token_id", None)
        stop_ids = []
        for s in self._stop_list(stop):
            if len(s) != 1:
                continue
            tid = self.tok.convert_tokens_to_ids(s)
            # Out-of-vocabulary tokens come back as None or as the unknown-token
            # id; neither is a usable stop token.
            if tid is None or tid == unk_id:
                continue
            stop_ids.append(tid)
        if not stop_ids:
            return None
        eos_id = getattr(self.tok, "eos_token_id", None)
        if eos_id is not None:
            stop_ids.insert(0, eos_id)
        return list(dict.fromkeys(stop_ids))

    def _stopping_criteria(self, stops: List[str], prompt_len: int):
        """Build a criteria list that stops once any stop string was generated."""
        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok

        class _StopOnText(StoppingCriteria):
            def __call__(self, input_ids, scores, **_):
                # Decode only the completion: a stop string that already occurs
                # in the prompt must not end generation immediately.
                text = tok.decode(input_ids[0][prompt_len:], skip_special_tokens=True)
                return any(s in text for s in stops)

        return StoppingCriteriaList([_StopOnText()])

    def _generation_kwargs(self, cfg, prompt_len: int) -> dict:
        """Translate ``cfg`` into keyword arguments for ``model.generate``."""
        sampling = cfg.temperature > 0.0
        gen_kwargs = {
            "max_new_tokens": cfg.max_tokens,
            "do_sample": sampling,
            # Greedy decoding: unset the sampling-only knobs instead of passing
            # values that have no effect.
            "temperature": cfg.temperature if sampling else None,
            "top_p": cfg.top_p if sampling else None,
            "pad_token_id": self.tok.pad_token_id,
        }
        eos_ids = self._build_eos_ids(cfg.stop)
        if eos_ids is not None:
            # Only override EOS when there is something to add; an explicit
            # None could otherwise replace the model's configured EOS.
            gen_kwargs["eos_token_id"] = eos_ids
        stops = self._stop_list(cfg.stop)
        if stops:
            gen_kwargs["stopping_criteria"] = self._stopping_criteria(stops, prompt_len)
        return gen_kwargs

    def generate(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text to condition on.
            cfg: Generation settings; a fresh ``GenerateConfig()`` is used when
                omitted.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            The decoded completion, without the prompt and without special
            tokens.
        """
        if cfg is None:
            # Built per call so no config instance is shared between calls.
            cfg = GenerateConfig()
        inputs = self.tok(prompt, return_tensors="pt").to(self._device())
        prompt_len = inputs["input_ids"].shape[1]
        out_ids = self.model.generate(**inputs, **self._generation_kwargs(cfg, prompt_len))
        # Decode only the new tokens
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def stream(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> Iterator[str]:
        """Yield the completion for ``prompt`` piece by piece as it is produced.

        Generation runs in a background thread feeding a
        ``TextIteratorStreamer``.  The same stop handling as ``generate`` is
        applied.  An exception raised by the model is re-raised here after the
        already produced text has been yielded.

        Args:
            prompt: Text to condition on.
            cfg: Generation settings; a fresh ``GenerateConfig()`` is used when
                omitted.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        import threading
        from transformers import TextIteratorStreamer

        if cfg is None:
            cfg = GenerateConfig()
        enc = self.tok(prompt, return_tensors="pt").to(self._device())
        streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = self._generation_kwargs(cfg, enc["input_ids"].shape[1])
        failures: List[BaseException] = []

        def _worker():
            try:
                self.model.generate(**enc, streamer=streamer, **gen_kwargs)
            except Exception as exc:
                failures.append(exc)
                # Without an end-of-stream signal the consumer below would
                # block on the streamer queue forever.
                streamer.end()

        t = threading.Thread(target=_worker)
        t.start()
        for text in streamer:
            yield text
        t.join()
        if failures:
            raise failures[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids without adding special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids back to text, dropping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)
class AWQLoader:
    """Loader plugin that detects AWQ checkpoints and wraps them in ``_AWQUnified``."""

    name = "awq"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Return whether ``source`` looks like an AWQ-quantized model.

        A local directory qualifies when its ``config.json`` mentions "awq"
        anywhere (case-insensitive).  A string containing "/" that is not an
        existing path is treated as a Hugging Face repo id and qualifies when
        the id itself contains "awq".  Everything else is rejected.

        Args:
            source: Local directory path or Hugging Face repo id.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``True`` if this loader should try the source, ``False`` otherwise
            (including when ``config.json`` is missing or unreadable).
        """
        if os.path.isdir(source):
            import json

            config_path = os.path.join(source, "config.json")
            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
            except (OSError, ValueError):
                # Missing, unreadable, or malformed config: detection is
                # best-effort, so report "not ours" rather than failing.
                return False
            # Searching the whole config also covers a nested
            # quantization_config section.
            return "awq" in str(config).lower()
        if "/" in source and not os.path.exists(source):  # HF repo
            # For HF repos, we'll let it try if it looks like AWQ
            return "awq" in source.lower()
        return False

    def load(self, source: str, **kwargs: Any) -> "UnifiedModel":
        """Load ``source`` as an AWQ model, forwarding ``kwargs`` to ``_AWQUnified``."""
        return _AWQUnified(source, **kwargs)