Files
dark_hal/llm_runtime/device_utils.py

99 lines
3.7 KiB
Python
Raw Normal View History

2026-03-13 12:56:43 -07:00
import torch
from typing import Union, Literal
# Backends whose device conventions normalize_device() knows how to produce.
Backend = Literal["hf", "gptq"]
# Normalized result: a device string, or (AutoGPTQ only) a bare GPU index.
DevOut = Union[str, int]
# Accepted device specs: None / a device-name string / a bare GPU index.
DevIn = Union[None, str, int]
def _has_mps() -> bool:
    """Report whether this torch build exposes an MPS backend that is usable now."""
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        return False
    return mps_backend.is_available()
def _first_cuda_index() -> int | None:
return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None
def normalize_device(dev: DevIn = None, *, backend: Backend = "hf") -> DevOut:
    """
    Normalize a user-supplied device spec into the form each backend expects.

    Inputs accepted (strings are case-insensitive; surrounding whitespace is
    ignored):
        None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int >= 0)

    Returns:
        backend == "hf":   "cpu" | "mps" | "cuda:N"
        backend == "gptq": "cpu" | "mps" | "disk" | N (int)

    Any ``backend`` value other than "gptq" is treated like "hf". Explicit
    device names and CUDA indices are returned as given; this function does
    not check that the named device is actually available.

    Raises:
        ValueError: for a malformed CUDA string, a negative GPU index, a bool,
            or any other unsupported spec.
    """

    def as_cuda(idx: int) -> DevOut:
        # AutoGPTQ takes a bare GPU index; HuggingFace takes a "cuda:N" string.
        return idx if backend == "gptq" else f"cuda:{idx}"

    # Canonicalize strings so "CPU", " cuda:0 ", "AUTO" behave like their
    # lowercase forms. The original `dev` is kept for error messages.
    spec = dev.strip().lower() if isinstance(dev, str) else dev

    # 1) Auto/default: prefer the first CUDA GPU, then MPS, then CPU.
    if spec is None or spec == "auto":
        cuda0 = _first_cuda_index()
        if cuda0 is not None:
            return as_cuda(cuda0)
        if _has_mps():
            return "mps"
        # AutoGPTQ can also run with 'disk' offload if caller asks; default to CPU here.
        return "cpu"

    if isinstance(spec, str):
        # 2) Explicit CPU/MPS/DISK. HF does not know "disk", so map it to CPU there.
        if spec in {"cpu", "mps", "disk"}:
            if spec == "disk" and backend != "gptq":
                return "cpu"
            return spec

        # 3) Explicit CUDA: "cuda" (first visible GPU) or "cuda:N".
        if spec == "cuda":
            idx = _first_cuda_index()
            if idx is None:
                # No CUDA available: fall back to CPU for every backend.
                return "cpu"
            return as_cuda(idx)
        if spec.startswith("cuda"):
            head, _, tail = spec.partition(":")
            try:
                idx = int(tail)
            except ValueError:
                idx = None
            # Reject things like "cudax:1" or "cuda:abc" instead of guessing.
            if head != "cuda" or idx is None:
                raise ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.")
            if idx < 0:
                raise ValueError(f"GPU index must be >= 0, got {idx}")
            return as_cuda(idx)

    # 4) Integer GPU index. bool is an int subclass, but True/False are not
    #    meaningful GPU indices, so they fall through to the error below.
    if isinstance(spec, int) and not isinstance(spec, bool):
        if spec < 0:
            raise ValueError(f"GPU index must be >= 0, got {spec}")
        return as_cuda(spec)

    raise ValueError(f"Unsupported device spec for backend={backend!r}: {dev!r}")
# --- Convenience wrappers -------------------------------------------------------
def device_for_hf(dev: DevIn = None) -> str:
    """Resolve *dev* for HuggingFace / PyTorch placement.

    The result is a device string such as ``'cuda:0'``, ``'cpu'`` or ``'mps'``.
    """
    resolved = normalize_device(dev, backend="hf")
    # Narrow the DevOut union: the "hf" backend only ever yields strings.
    assert isinstance(resolved, str)
    return resolved
def device_for_gptq(dev: DevIn = None) -> Union[int, str]:
    """Resolve *dev* for AutoGPTQ.

    CUDA devices come back as a bare GPU index (``int``); otherwise the result
    is one of the strings ``'cpu'``, ``'mps'`` or ``'disk'``.
    """
    resolved = normalize_device(dev, backend="gptq")
    # Defensive check on the normalizer's result type.
    assert isinstance(resolved, (int, str))
    return resolved
def debug_device_placement(model, name="model"):
    """Print which devices *model*'s parameters live on and return the primary one.

    This is a best-effort debugging aid. Any error raised while inspecting the
    model is printed rather than propagated.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        name: label used in the printed messages.

    Returns:
        The ``device`` of the first parameter, or ``None`` if the model has no
        parameters or could not be inspected.
    """
    try:
        devices = set()
        first_device = None
        # A single pass collects every device and remembers the first one.
        for param in model.parameters():
            if first_device is None:
                first_device = param.device
            devices.add(str(param.device))
        print(f"[DEBUG] {name} parameters on devices: {devices}")
        if first_device is None:
            # Previously surfaced as an opaque StopIteration from next().
            print(f"[DEBUG] {name} has no parameters; primary device unknown")
            return None
        print(f"[DEBUG] {name} primary device: {first_device}")
        return first_device
    except Exception as e:
        # Deliberate best-effort: report and keep going instead of crashing.
        print(f"[DEBUG] Could not check {name} device placement: {e}")
        return None
# --- Minimal self-test (run this file directly) ---------------------------------
if __name__ == "__main__":
    samples = (None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1)
    for spec in samples:
        # Evaluate in the same order as before: HF first, then GPTQ.
        hf_dev = device_for_hf(spec)
        gptq_dev = device_for_gptq(spec)
        print(f"in={spec!r:7} -> hf={hf_dev!r:7} gptq={gptq_dev!r}")