132 lines
5.0 KiB
Python
132 lines
5.0 KiB
Python
from typing import Any, Iterator, List, Optional
|
|
import os
|
|
from llm_runtime.types import UnifiedModel, GenerateConfig
|
|
|
|
class _AWQUnified:
|
|
def __init__(self, src: str, **kwargs: Any):
|
|
try:
|
|
from autoawq import AutoAWQForCausalLM
|
|
from transformers import AutoTokenizer
|
|
except ImportError:
|
|
raise ImportError("autoawq and transformers are required for AWQ models. Install with: pip install autoawq transformers")
|
|
|
|
print(f"Loading AWQ model from: {src}")
|
|
|
|
# Load tokenizer
|
|
self.tok = AutoTokenizer.from_pretrained(
|
|
src,
|
|
use_fast=True,
|
|
trust_remote_code=kwargs.get("trust_remote_code", True)
|
|
)
|
|
|
|
# Load AWQ model
|
|
self.model = AutoAWQForCausalLM.from_quantized(
|
|
src,
|
|
device_map=kwargs.get("device_map", "auto"),
|
|
trust_remote_code=kwargs.get("trust_remote_code", True),
|
|
safetensors=True,
|
|
)
|
|
|
|
# Set pad token if not present
|
|
if self.tok.pad_token is None:
|
|
self.tok.pad_token = self.tok.eos_token
|
|
|
|
def _build_eos_ids(self, stop) -> Optional[List[int]]:
|
|
"""Convert stop strings to token IDs"""
|
|
if not stop:
|
|
return None
|
|
|
|
ids = []
|
|
for s in stop:
|
|
if len(s) == 1:
|
|
tid = self.tok.convert_tokens_to_ids(s)
|
|
if tid is not None:
|
|
ids.append(tid)
|
|
return ids or None
|
|
|
|
def generate(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> str:
|
|
from transformers import StoppingCriteria, StoppingCriteriaList
|
|
|
|
inputs = self.tok(prompt, return_tensors="pt").to(self.model.device)
|
|
|
|
class MultiStringStop(StoppingCriteria):
|
|
def __init__(self, toks, stops):
|
|
self.toks, self.stops = toks, stops or []
|
|
|
|
def __call__(self, input_ids, scores, **_):
|
|
text = self.toks.decode(input_ids[0], skip_special_tokens=True)
|
|
return any(s in text for s in self.stops)
|
|
|
|
out_ids = self.model.generate(
|
|
**inputs,
|
|
max_new_tokens=cfg.max_tokens,
|
|
do_sample=cfg.temperature > 0.0,
|
|
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
|
|
top_p=cfg.top_p,
|
|
eos_token_id=self._build_eos_ids(cfg.stop),
|
|
stopping_criteria=StoppingCriteriaList([MultiStringStop(self.tok, cfg.stop)]) if cfg.stop else None,
|
|
pad_token_id=self.tok.pad_token_id,
|
|
)
|
|
|
|
# Decode only the new tokens
|
|
generated_text = self.tok.decode(out_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
|
return generated_text
|
|
|
|
def stream(self, prompt: str, cfg: GenerateConfig = GenerateConfig(), **kwargs: Any) -> Iterator[str]:
|
|
import threading
|
|
from transformers import TextIteratorStreamer
|
|
|
|
enc = self.tok(prompt, return_tensors="pt").to(self.model.device)
|
|
streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
def _worker():
|
|
self.model.generate(
|
|
**enc,
|
|
max_new_tokens=cfg.max_tokens,
|
|
do_sample=cfg.temperature > 0.0,
|
|
temperature=cfg.temperature if cfg.temperature > 0.0 else None,
|
|
top_p=cfg.top_p,
|
|
eos_token_id=self._build_eos_ids(cfg.stop),
|
|
streamer=streamer,
|
|
pad_token_id=self.tok.pad_token_id,
|
|
)
|
|
|
|
t = threading.Thread(target=_worker)
|
|
t.start()
|
|
|
|
for text in streamer:
|
|
yield text
|
|
|
|
def tokenize(self, text: str) -> List[int]:
|
|
return self.tok.encode(text, add_special_tokens=False)
|
|
|
|
def detokenize(self, ids: List[int]) -> str:
|
|
return self.tok.decode(ids, skip_special_tokens=True)
|
|
|
|
class AWQLoader:
|
|
name = "awq"
|
|
|
|
def can_load(self, source: str, **kwargs: Any) -> bool:
|
|
# Check for AWQ indicators
|
|
if os.path.isdir(source):
|
|
# Look for AWQ indicators in config.json
|
|
try:
|
|
import json
|
|
config_path = os.path.join(source, "config.json")
|
|
if os.path.exists(config_path):
|
|
with open(config_path, 'r') as f:
|
|
config = json.load(f)
|
|
config_str = str(config).lower()
|
|
return ("awq" in config_str or
|
|
"quantization_config" in config and
|
|
"awq" in str(config.get("quantization_config", {})).lower())
|
|
except:
|
|
pass
|
|
elif "/" in source and not os.path.exists(source): # HF repo
|
|
# For HF repos, we'll let it try if it looks like AWQ
|
|
return "awq" in source.lower()
|
|
|
|
return False
|
|
|
|
def load(self, source: str, **kwargs: Any) -> UnifiedModel:
|
|
return _AWQUnified(source, **kwargs) |