first commit
This commit is contained in:
132
llm_runtime/loaders/awq_loader.py
Normal file
132
llm_runtime/loaders/awq_loader.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Any, Iterator, List, Optional
|
||||
import os
|
||||
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||||
|
||||
class _AWQUnified:
    """Unified-model wrapper around an AutoAWQ-quantized causal language model.

    Holds a Hugging Face tokenizer (``self.tok``) and the quantized model
    (``self.model``) and exposes ``generate``, ``stream``, ``tokenize`` and
    ``detokenize`` on top of them.
    """

    def __init__(self, src: str, **kwargs: Any):
        """Load the tokenizer and the quantized model found at ``src``.

        Args:
            src: Local path or Hugging Face repo id of the AWQ checkpoint.
            **kwargs: Only ``trust_remote_code`` (default ``True``) and
                ``device_map`` (default ``"auto"``) are read; other keys are
                ignored.

        Raises:
            ImportError: If ``autoawq`` or ``transformers`` is not installed.
        """
        try:
            # The ``autoawq`` distribution installs its code as the ``awq``
            # package, so that is the name that has to be imported.
            from awq import AutoAWQForCausalLM
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers") from exc

        print(f"Loading AWQ model from: {src}")

        # NOTE(review): defaulting to True executes repository-supplied code;
        # kept for backward compatibility, callers can pass False explicitly.
        trust_remote_code = kwargs.get("trust_remote_code", True)

        # Load tokenizer
        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote_code,
        )

        # Load AWQ model
        self.model = AutoAWQForCausalLM.from_quantized(
            src,
            device_map=kwargs.get("device_map", "auto"),
            trust_remote_code=trust_remote_code,
            safetensors=True,
        )

        # Set pad token if not present
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    @staticmethod
    def _normalize_stops(stop) -> List[str]:
        """Return ``stop`` as a list of non-empty strings.

        A bare string is treated as one stop sequence (not as a sequence of
        characters), and empty strings are dropped because ``"" in text`` is
        always true and would stop generation immediately.
        """
        if not stop:
            return []
        if isinstance(stop, str):
            return [stop]
        return [s for s in stop if s]

    def _device(self):
        """Return the device the prompt tensors must be moved to."""
        # NOTE(review): the AutoAWQ wrapper may not expose ``.device`` itself;
        # presumably the wrapped HF model lives on ``.model`` - confirm.
        for candidate in (self.model, getattr(self.model, "model", None)):
            device = getattr(candidate, "device", None)
            if device is not None:
                return device
        return next(self.model.parameters()).device

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Convert single-character stop strings to token IDs.

        Returns:
            ``None`` when no stop string maps to a known token. Otherwise the
            stop token ids followed by the tokenizer's own EOS id, because an
            explicit ``eos_token_id`` replaces the default one and the model
            would otherwise no longer stop at its real end-of-sequence token.
        """
        unk_id = getattr(self.tok, "unk_token_id", None)
        ids: List[int] = []
        for s in self._normalize_stops(stop):
            if len(s) != 1:
                # Longer stop strings are handled by the text stopping criteria.
                continue
            tid = self.tok.convert_tokens_to_ids(s)
            # Tokenizers report unknown tokens as the unk id rather than None.
            if tid is None or tid == unk_id or tid in ids:
                continue
            ids.append(tid)
        if not ids:
            return None

        eos_id = getattr(self.tok, "eos_token_id", None)
        if eos_id is not None and eos_id not in ids:
            ids.append(eos_id)
        return ids

    def _stopping_criteria(self, stop, prompt_len: int):
        """Build a ``StoppingCriteriaList`` that halts on any stop string.

        Only the tokens after ``prompt_len`` are decoded, so a stop string
        that already occurs in the prompt does not end generation at once.
        Returns ``None`` when there are no stop strings.
        """
        stops = self._normalize_stops(stop)
        if not stops:
            return None

        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok

        class _StopOnText(StoppingCriteria):
            def __call__(self, input_ids, scores, **_):
                text = tok.decode(input_ids[0][prompt_len:], skip_special_tokens=True)
                return any(s in text for s in stops)

        return StoppingCriteriaList([_StopOnText()])

    def _generation_kwargs(self, cfg, prompt_len: int) -> dict:
        """Translate ``cfg`` into keyword arguments for ``model.generate``.

        Shared by ``generate`` and ``stream`` so both honour the same
        sampling and stop settings. ``eos_token_id`` and
        ``stopping_criteria`` are only included when they carry a value, so
        the model's own defaults are not overridden with ``None``.
        """
        gen_kwargs = {
            "max_new_tokens": cfg.max_tokens,
            "pad_token_id": self.tok.pad_token_id,
        }

        temperature = cfg.temperature or 0.0
        if temperature > 0.0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
            if cfg.top_p is not None:
                gen_kwargs["top_p"] = cfg.top_p
        else:
            # Greedy decoding: unset the sampling knobs so sampling defaults
            # stored with the model do not apply.
            gen_kwargs["do_sample"] = False
            gen_kwargs["temperature"] = None
            gen_kwargs["top_p"] = None

        eos_ids = self._build_eos_ids(cfg.stop)
        if eos_ids is not None:
            gen_kwargs["eos_token_id"] = eos_ids

        criteria = self._stopping_criteria(cfg.stop, prompt_len)
        if criteria is not None:
            gen_kwargs["stopping_criteria"] = criteria

        return gen_kwargs

    def generate(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text to continue.
            cfg: Generation settings; a fresh ``GenerateConfig()`` is created
                per call when omitted, so no instance is shared between calls.
            **kwargs: Accepted for interface compatibility; not used.
        """
        if cfg is None:
            cfg = GenerateConfig()

        inputs = self.tok(prompt, return_tensors="pt").to(self._device())
        prompt_len = inputs.input_ids.shape[1]

        out_ids = self.model.generate(
            **inputs,
            **self._generation_kwargs(cfg, prompt_len),
        )

        # Decode only the new tokens
        return self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    def stream(self, prompt: str, cfg: "Optional[GenerateConfig]" = None, **kwargs: Any) -> Iterator[str]:
        """Yield the completion for ``prompt`` piece by piece.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``. If generation fails, the streamer is closed
        so the consumer is not left blocked, and the error is re-raised here
        once the already produced text has been yielded.

        Args:
            prompt: Text to continue.
            cfg: Generation settings; a fresh ``GenerateConfig()`` is created
                per call when omitted.
            **kwargs: Accepted for interface compatibility; not used.
        """
        import threading
        from transformers import TextIteratorStreamer

        if cfg is None:
            cfg = GenerateConfig()

        enc = self.tok(prompt, return_tensors="pt").to(self._device())
        gen_kwargs = self._generation_kwargs(cfg, enc.input_ids.shape[1])
        streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        errors: List[BaseException] = []

        def _worker():
            try:
                self.model.generate(**enc, streamer=streamer, **gen_kwargs)
            except Exception as exc:
                errors.append(exc)
                # Signal end-of-stream so the consuming loop below terminates.
                streamer.end()

        t = threading.Thread(target=_worker)
        t.start()

        yield from streamer

        t.join()
        if errors:
            raise errors[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids without adding special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token ids back to text, skipping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)
|
||||
|
||||
class AWQLoader:
    """Loader plugin that recognises and loads AWQ-quantized checkpoints."""

    name = "awq"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Return whether ``source`` looks like an AWQ checkpoint.

        A local directory qualifies when its ``config.json`` mentions "awq"
        anywhere (case-insensitive). A string containing "/" that does not
        exist on disk is treated as a Hugging Face repo id and qualifies when
        the id itself contains "awq". Everything else is rejected.

        Args:
            source: Local directory path or repo id.
            **kwargs: Accepted for interface compatibility; not used.
        """
        if os.path.isdir(source):
            return self._config_mentions_awq(os.path.join(source, "config.json"))

        if "/" in source and not os.path.exists(source):  # HF repo
            # For HF repos, we'll let it try if it looks like AWQ
            return "awq" in source.lower()

        return False

    @staticmethod
    def _config_mentions_awq(config_path: str) -> bool:
        """Return True if the JSON file at ``config_path`` mentions "awq".

        A missing, unreadable or malformed file counts as "not AWQ". Only
        I/O and parse errors are swallowed, so interrupts and programming
        errors still propagate.
        """
        import json

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return False

        # Deliberately loose: matches "awq" in any key or value, which also
        # covers a "quantization_config" section that names AWQ.
        return "awq" in str(config).lower()

    def load(self, source: str, **kwargs: Any) -> "UnifiedModel":
        """Load ``source`` as an AWQ model, forwarding ``kwargs`` to it."""
        return _AWQUnified(source, **kwargs)
|
||||
Reference in New Issue
Block a user