176 lines
6.7 KiB
Python
176 lines
6.7 KiB
Python
|
|
from typing import Any, Iterator, List, Optional
|
||
|
|
import os, json
|
||
|
|
import torch
|
||
|
|
from llm_runtime.types import UnifiedModel, GenerateConfig
|
||
|
|
from llm_runtime.device_utils import device_for_gptq
|
||
|
|
|
||
|
|
def _inputs_device_from_gptq_device(dev) -> torch.device:
    """Map an AutoGPTQ device spec to the torch.device tokenizer tensors should use.

    AutoGPTQ wants: int GPU index | 'cpu' | 'mps' | 'disk'. The exact values
    produced by ``device_for_gptq`` are defined elsewhere, so CUDA strings and
    ``torch.device`` objects are also handled defensively.

    Mapping:
      - int (not bool)       -> 'cuda:{idx}'
      - torch.device         -> returned unchanged
      - 'cpu' / 'mps'        -> same device
      - 'cuda' / 'cuda:N'    -> same CUDA device (keeps inputs next to a GPU model)
      - 'disk'               -> 'cpu' (inputs stay on CPU)
      - anything else / None -> 'cpu' (conservative fallback)

    Strings are matched case-insensitively after stripping whitespace.
    """
    if isinstance(dev, torch.device):
        return dev
    # bool is a subclass of int but is never a meaningful GPU index.
    if isinstance(dev, int) and not isinstance(dev, bool):
        return torch.device(f"cuda:{dev}")
    if isinstance(dev, str):
        name = dev.strip().lower()
        if name in ("cpu", "mps"):
            return torch.device(name)
        if name == "cuda" or name.startswith("cuda:"):
            try:
                return torch.device(name)
            except (RuntimeError, ValueError):
                pass  # malformed index such as 'cuda:x' -> CPU fallback below
        # 'disk' and unrecognised strings: inputs live on CPU.
    # Fallback
    return torch.device("cpu")
|
||
|
|
|
||
|
|
class _GPTQUnified:
    """Wrapper around an AutoGPTQ quantized causal LM and its tokenizer.

    Loads both in ``__init__`` and exposes ``generate`` (blocking),
    ``stream`` (incremental text chunks), ``tokenize`` and ``detokenize``.
    """

    def __init__(self, src: str, **kwargs: Any):
        """Load the tokenizer and quantized model from *src*.

        Args:
            src: local directory or Hugging Face repo id.
            **kwargs: keys read here are ``device`` (normalised via
                ``device_for_gptq``), ``trust_remote_code`` (default True),
                ``token`` (HF auth token), ``use_triton`` (default False) and
                ``use_safetensors`` (default True). Other keys are ignored.

        Raises:
            ImportError: if ``auto_gptq`` or ``transformers`` is not installed.
        """
        import logging

        log = logging.getLogger(__name__)
        # Never write auth tokens to logs.
        safe_kwargs = {
            k: ("<redacted>" if "token" in str(k).lower() and v else v)
            for k, v in kwargs.items()
        }
        log.debug("_GPTQUnified.__init__(src=%r, kwargs=%r)", src, safe_kwargs)

        try:
            from auto_gptq import AutoGPTQForCausalLM
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError("auto-gptq and transformers are required for GPTQ models. Install with: pip install auto-gptq transformers") from e

        # Normalize device for AutoGPTQ (int or 'cpu'/'mps'/'disk').
        raw_device = kwargs.get("device")
        self._gptq_dev = device_for_gptq(raw_device)
        self._inputs_device = _inputs_device_from_gptq_device(self._gptq_dev)
        log.debug(
            "GPTQ device: raw=%r gptq=%r inputs=%s",
            raw_device, self._gptq_dev, self._inputs_device,
        )

        # NOTE(review): trust_remote_code defaults to True, which lets the model
        # repo run its own Python code. Consider defaulting to False for
        # untrusted sources; kept True here for backward compatibility.
        trust_remote = kwargs.get("trust_remote_code", True)
        token = kwargs.get("token")

        self.tok = AutoTokenizer.from_pretrained(
            src,
            use_fast=True,
            trust_remote_code=trust_remote,
            token=token,
        )

        log.debug("Loading GPTQ model from: %s", src)
        self.model = AutoGPTQForCausalLM.from_quantized(
            src,
            device=self._gptq_dev,  # int or 'cpu'/'mps'/'disk'
            trust_remote_code=trust_remote,
            use_safetensors=kwargs.get("use_safetensors", True),
            use_triton=kwargs.get("use_triton", False),
            token=token,
        )

        # Pad token safety: generate() is given pad_token_id explicitly.
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_stops(stop) -> List[str]:
        """Return *stop* as a list of non-empty strings (a bare str is one stop)."""
        if not stop:
            return []
        if isinstance(stop, str):
            return [stop]
        return [s for s in stop if s]

    @staticmethod
    def _truncate_at_stop(text: str, stops: List[str]) -> str:
        """Cut *text* just before the earliest occurrence of any stop string."""
        cut = min((i for i in (text.find(s) for s in stops) if i >= 0), default=len(text))
        return text[:cut]

    @staticmethod
    def _hold_back_stops(chunks: Iterator[str], stops: List[str]) -> Iterator[str]:
        """Re-emit streamed *chunks*, ending just before the first stop string.

        The trailing ``max(len(stop)) - 1`` characters are withheld until it is
        certain they do not begin a stop string, so a stop split across chunk
        boundaries is never leaked. With no stops, chunks pass through as-is.
        """
        if not stops:
            yield from chunks
            return
        hold = max(len(s) for s in stops) - 1
        buf = ""
        for chunk in chunks:
            buf += chunk
            hits = [i for i in (buf.find(s) for s in stops) if i >= 0]
            if hits:
                head = buf[:min(hits)]
                if head:
                    yield head
                return
            emit = len(buf) - hold
            if emit > 0:
                yield buf[:emit]
                buf = buf[emit:]
        if buf:
            yield buf

    def _build_eos_ids(self, stop) -> Optional[List[int]]:
        """Return token IDs for stop strings that encode to exactly one token.

        Multi-token stop strings are skipped: using only their first token as
        an EOS id would end generation on any text that merely starts the same
        way. ``generate``/``stream`` do not override ``eos_token_id``; they
        match stop strings on decoded text instead. Returns None if nothing
        qualifies.
        """
        out: List[int] = []
        for s in self._normalize_stops(stop):
            ids = self.tok.encode(s, add_special_tokens=False)
            if len(ids) == 1 and ids[0] not in out:
                out.append(ids[0])
        return out or None

    def _encode(self, prompt: str) -> dict:
        """Tokenize *prompt* and move every tensor to the inputs device.

        ``token_type_ids`` is dropped: some tokenizers emit it, and transformers'
        ``generate`` may reject it as an unused model kwarg for causal LMs.
        """
        enc = self.tok(prompt, return_tensors="pt")
        return {k: v.to(self._inputs_device) for k, v in enc.items() if k != "token_type_ids"}

    def _gen_kwargs(self, cfg) -> dict:
        """Translate a GenerateConfig into ``model.generate`` keyword arguments.

        Sampling is enabled only for a positive temperature. ``temperature`` and
        ``top_p`` are passed only when sampling, since greedy decoding ignores
        them. ``eos_token_id`` is not set, so the model's default EOS handling
        stays in effect.
        """
        do_sample = cfg.temperature is not None and cfg.temperature > 0.0
        kw = {
            "max_new_tokens": cfg.max_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tok.pad_token_id,
        }
        if do_sample:
            kw["temperature"] = float(cfg.temperature)
            kw["top_p"] = cfg.top_p
        return kw

    def _stopping_criteria(self, stops: List[str], prompt_len: int, cancel=None):
        """Build a StoppingCriteriaList for *stops* and/or a cancel Event, or None."""
        from transformers import StoppingCriteria, StoppingCriteriaList

        tok = self.tok

        class _StopOnStrings(StoppingCriteria):
            # Decode only the generated part: a stop string that already occurs
            # in the prompt must not end generation immediately.
            # Simple but effective; for high perf, implement a token-level matcher.
            def __call__(self, input_ids, scores, **_):
                text = tok.decode(input_ids[0][prompt_len:], skip_special_tokens=True)
                return any(s in text for s in stops)

        class _StopOnCancel(StoppingCriteria):
            # Lets a stream consumer abort the background generation.
            def __call__(self, input_ids, scores, **_):
                return cancel.is_set()

        criteria = []
        if stops:
            criteria.append(_StopOnStrings())
        if cancel is not None:
            criteria.append(_StopOnCancel())
        return StoppingCriteriaList(criteria) if criteria else None

    # --------------------------------------------------------------- public API

    def generate(self, prompt: str, cfg: Optional["GenerateConfig"] = None, **kwargs: Any) -> str:
        """Generate a completion for *prompt* and return only the new text.

        Args:
            prompt: input text.
            cfg: generation settings; a fresh ``GenerateConfig()`` when None.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The decoded new tokens (special tokens skipped), cut just before the
            first stop string from ``cfg.stop`` if one was generated.
        """
        if cfg is None:
            cfg = GenerateConfig()
        stops = self._normalize_stops(cfg.stop)
        inputs = self._encode(prompt)
        prompt_len = inputs["input_ids"].shape[1]

        gen_kwargs = self._gen_kwargs(cfg)
        criteria = self._stopping_criteria(stops, prompt_len)
        if criteria is not None:
            gen_kwargs["stopping_criteria"] = criteria

        out_ids = self.model.generate(**inputs, **gen_kwargs)

        # Decode only new tokens.
        text = self.tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
        return self._truncate_at_stop(text, stops)

    def stream(self, prompt: str, cfg: Optional["GenerateConfig"] = None, **kwargs: Any) -> Iterator[str]:
        """Yield generated text chunks for *prompt* as they are produced.

        Generation runs on a daemon thread that feeds a TextIteratorStreamer.
        Output ends just before the first stop string from ``cfg.stop``.
        If the consumer stops iterating early, a cancel flag is set and checked
        by a stopping criterion. An exception raised by the worker is re-raised
        here once the stream ends, instead of leaving the consumer blocked.
        """
        import threading
        from transformers import TextIteratorStreamer

        if cfg is None:
            cfg = GenerateConfig()
        stops = self._normalize_stops(cfg.stop)
        inputs = self._encode(prompt)
        prompt_len = inputs["input_ids"].shape[1]
        streamer = TextIteratorStreamer(self.tok, skip_prompt=True, skip_special_tokens=True)
        cancel = threading.Event()

        gen_kwargs = self._gen_kwargs(cfg)
        gen_kwargs["streamer"] = streamer
        gen_kwargs["stopping_criteria"] = self._stopping_criteria(stops, prompt_len, cancel)

        errors: List[BaseException] = []

        def _worker():
            try:
                self.model.generate(**inputs, **gen_kwargs)
            except BaseException as exc:  # re-raised in the consumer below
                errors.append(exc)
                streamer.end()  # push the stop signal so iteration terminates

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        try:
            yield from self._hold_back_stops(iter(streamer), stops)
        finally:
            # Also runs when the consumer closes the generator early.
            cancel.set()
        worker.join()
        if errors:
            raise errors[0]

    def tokenize(self, text: str) -> List[int]:
        """Encode *text* to token IDs without adding special tokens."""
        return self.tok.encode(text, add_special_tokens=False)

    def detokenize(self, ids: List[int]) -> str:
        """Decode token IDs to text, skipping special tokens."""
        return self.tok.decode(ids, skip_special_tokens=True)
|
||
|
|
|
||
|
|
class AutoGPTQLoader:
    """Loader plugin that recognises GPTQ checkpoints and wraps them in _GPTQUnified."""

    name = "gptq"

    def can_load(self, source: str, **kwargs: Any) -> bool:
        """Heuristically decide whether *source* is a GPTQ checkpoint.

        - Local directory containing ``quantize_config.json``: True (strong signal).
        - Local directory with ``config.json``:
            - If ``quantization_config.quant_method`` is present, it alone
              decides, and it must equal "gptq" (case-insensitive). This stops
              AWQ/bitsandbytes checkpoints from being claimed.
            - Otherwise, any mention of "gptq" or "quantize_config" counts.
        - Local directory whose ``config.json`` is missing, unreadable or not
          valid JSON: False.
        - Non-existent path containing "/": assumed to be a Hub repo id, so
          return True and let the loader try.
        - Anything else: False.
        """
        if os.path.isdir(source):
            if os.path.isfile(os.path.join(source, "quantize_config.json")):
                return True
            try:
                with open(os.path.join(source, "config.json"), "r", encoding="utf-8") as f:
                    cfg = json.load(f)
            except (OSError, ValueError):  # missing/unreadable file or invalid JSON
                return False
            qcfg = cfg.get("quantization_config") if isinstance(cfg, dict) else None
            if isinstance(qcfg, dict) and "quant_method" in qcfg:
                return str(qcfg["quant_method"]).lower() == "gptq"
            text = json.dumps(cfg).lower()
            return ("gptq" in text) or ("quantize_config" in text)
        # HF repo path (heuristic): let the loader try.
        return "/" in source and not os.path.exists(source)

    def load(self, source: str, **kwargs: Any) -> "UnifiedModel":
        """Load *source* as a GPTQ model; *kwargs* are forwarded to _GPTQUnified."""
        return _GPTQUnified(source, **kwargs)
|