first commit
This commit is contained in:
99
llm_runtime/device_utils.py
Normal file
99
llm_runtime/device_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from typing import Union, Literal
|
||||
|
||||
Backend = Literal["hf", "gptq"]
|
||||
DevIn = Union[None, str, int]
|
||||
DevOut = Union[str, int]
|
||||
|
||||
def _has_mps() -> bool:
|
||||
return getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
|
||||
|
||||
def _first_cuda_index() -> int | None:
|
||||
return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None
|
||||
|
||||
def normalize_device(dev: DevIn = None, *, backend: Backend = "hf") -> DevOut:
|
||||
"""
|
||||
Normalize a user/device string/int into what each backend expects.
|
||||
|
||||
Inputs accepted:
|
||||
None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int)
|
||||
|
||||
Returns:
|
||||
backend == "hf": "cpu" | "mps" | "cuda:N"
|
||||
backend == "gptq": "cpu" | "mps" | "disk" | N (int)
|
||||
"""
|
||||
# 1) Auto/default
|
||||
if dev in (None, "auto"):
|
||||
cuda0 = _first_cuda_index()
|
||||
if cuda0 is not None:
|
||||
return (cuda0 if backend == "gptq" else f"cuda:{cuda0}")
|
||||
if _has_mps():
|
||||
return "mps"
|
||||
# AutoGPTQ can also run with 'disk' offload if caller wants; default to CPU here
|
||||
return "cpu"
|
||||
|
||||
# 2) Explicit CPU/MPS/DISK
|
||||
if isinstance(dev, str) and dev.lower() in {"cpu", "mps", "disk"}:
|
||||
# HF does not know "disk"; treat as CPU for HF branch
|
||||
return dev if backend == "gptq" or dev != "disk" else "cpu"
|
||||
|
||||
# 3) Explicit CUDA string
|
||||
if isinstance(dev, str) and dev.lower().startswith("cuda"):
|
||||
# Accept "cuda" and "cuda:N"
|
||||
if dev == "cuda":
|
||||
idx = _first_cuda_index()
|
||||
if idx is None:
|
||||
# No CUDA available; degrade to CPU/MPS appropriately
|
||||
return "cpu" if backend == "hf" else "cpu"
|
||||
return (idx if backend == "gptq" else f"cuda:{idx}")
|
||||
# cuda:N
|
||||
try:
|
||||
idx = int(dev.split(":", 1)[1])
|
||||
except (IndexError, ValueError):
|
||||
raise ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.")
|
||||
return (idx if backend == "gptq" else f"cuda:{idx}")
|
||||
|
||||
# 4) Integer GPU index
|
||||
if isinstance(dev, int):
|
||||
if dev < 0:
|
||||
raise ValueError(f"GPU index must be >= 0, got {dev}")
|
||||
return (dev if backend == "gptq" else f"cuda:{dev}")
|
||||
|
||||
raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}")
|
||||
|
||||
# --- Convenience wrappers -------------------------------------------------------
|
||||
|
||||
def device_for_hf(dev: DevIn = None) -> str:
|
||||
"""Return a device string suitable for HuggingFace (e.g., 'cuda:0', 'cpu', 'mps')."""
|
||||
out = normalize_device(dev, backend="hf")
|
||||
assert isinstance(out, str)
|
||||
return out
|
||||
|
||||
def device_for_gptq(dev: DevIn = None) -> Union[int, str]:
|
||||
"""Return an int GPU index or 'cpu'/'mps'/'disk' for AutoGPTQ."""
|
||||
out = normalize_device(dev, backend="gptq")
|
||||
assert isinstance(out, (int, str))
|
||||
return out
|
||||
|
||||
def debug_device_placement(model, name="model"):
|
||||
"""Debug helper to check where model parameters are placed"""
|
||||
try:
|
||||
devices = set()
|
||||
for name_param, param in model.named_parameters():
|
||||
devices.add(str(param.device))
|
||||
print(f"[DEBUG] {name} parameters on devices: {devices}")
|
||||
|
||||
# Check first parameter device
|
||||
first_param = next(model.parameters())
|
||||
print(f"[DEBUG] {name} primary device: {first_param.device}")
|
||||
|
||||
return first_param.device
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Could not check {name} device placement: {e}")
|
||||
return None
|
||||
|
||||
# --- Minimal self-test (run this file directly) ---------------------------------
|
||||
if __name__ == "__main__":
|
||||
tests = [None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1]
|
||||
for t in tests:
|
||||
print(f"in={t!r:7} -> hf={device_for_hf(t)!r:7} gptq={device_for_gptq(t)!r}")
|
||||
Reference in New Issue
Block a user