99 lines
3.7 KiB
Python
99 lines
3.7 KiB
Python
|
|
import torch
|
||
|
|
from typing import Union, Literal
|
||
|
|
|
||
|
|
Backend = Literal["hf", "gptq"]
|
||
|
|
DevIn = Union[None, str, int]
|
||
|
|
DevOut = Union[str, int]
|
||
|
|
|
||
|
|
def _has_mps() -> bool:
    """Return True when ``torch.backends`` exposes ``mps`` and it reports itself available."""
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        # This torch build has no ``torch.backends.mps`` attribute at all.
        return False
    return mps_backend.is_available()
|
||
|
|
|
||
|
|
def _first_cuda_index() -> int | None:
|
||
|
|
return 0 if torch.cuda.is_available() and torch.cuda.device_count() > 0 else None
|
||
|
|
|
||
|
|
def normalize_device(dev: DevIn = None, *, backend: Backend = "hf") -> DevOut:
    """
    Normalize a user/device string/int into what each backend expects.

    String specs are matched case-insensitively and surrounding whitespace is
    ignored, so "CUDA:1", " cpu " and "Disk" are accepted and returned in
    canonical lower-case form.

    Inputs accepted:
        None | "auto" | "cpu" | "mps" | "disk" | "cuda" | "cuda:N" | N (int)

    Returns:
        backend == "hf":   "cpu" | "mps" | "cuda:N"
        backend == "gptq": "cpu" | "mps" | "disk" | N (int)

        "disk" is mapped to "cpu" for any backend other than "gptq".
        A bare "cuda" falls back to "cpu" when no CUDA device is found.

    Raises:
        ValueError: for a malformed CUDA string, a negative GPU index (given
            as an int or as "cuda:N"), a bool, or any other unsupported spec.
    """

    def _gpu(index):
        # Single place that validates a GPU index and shapes it per backend:
        # the "gptq" backend gets the bare int, everything else gets "cuda:N".
        if index < 0:
            raise ValueError(f"GPU index must be >= 0, got {index}")
        return index if backend == "gptq" else f"cuda:{index}"

    unsupported = f"Unsupported device spec for backend={backend!r}: {dev!r}"

    # bool is a subclass of int; reject it explicitly so True/False are not
    # silently interpreted as GPU 1/0.
    if isinstance(dev, bool):
        raise ValueError(unsupported)

    # Integer GPU index
    if isinstance(dev, int):
        return _gpu(dev)

    if dev is not None and not isinstance(dev, str):
        raise ValueError(unsupported)

    # Canonicalize once so every comparison below is case/whitespace-insensitive
    # and the value handed back to the caller is the canonical spelling.
    spec = "auto" if dev is None else dev.strip().lower()

    # Auto/default: prefer CUDA, then MPS, then CPU
    if spec == "auto":
        cuda0 = _first_cuda_index()
        if cuda0 is not None:
            return _gpu(cuda0)
        if _has_mps():
            return "mps"
        # AutoGPTQ can also run with 'disk' offload if caller wants; default to CPU here
        return "cpu"

    # Explicit CPU/MPS/DISK
    if spec in {"cpu", "mps", "disk"}:
        # HF does not know "disk"; treat as CPU for HF branch
        if spec == "disk" and backend != "gptq":
            return "cpu"
        return spec

    # Explicit CUDA string: "cuda" or "cuda:N"
    if spec == "cuda":
        idx = _first_cuda_index()
        if idx is None:
            # No CUDA available; degrade to CPU for every backend
            return "cpu"
        return _gpu(idx)

    if spec.startswith("cuda"):
        prefix, sep, tail = spec.partition(":")
        try:
            # Only an exact "cuda" prefix followed by ":<int>" is well-formed.
            idx = int(tail) if (prefix == "cuda" and sep) else None
        except ValueError:
            idx = None
        if idx is None:
            raise ValueError(f"Bad CUDA device string: {dev!r}. Use 'cuda' or 'cuda:N'.")
        return _gpu(idx)

    raise ValueError(unsupported)
|
||
|
|
|
||
|
|
# --- Convenience wrappers -------------------------------------------------------
|
||
|
|
|
||
|
|
def device_for_hf(dev: DevIn = None) -> str:
    """Resolve *dev* via ``normalize_device(..., backend="hf")`` and return the device string."""
    hf_device = normalize_device(dev, backend="hf")
    # Narrow the Union return type: the declared contract here is ``str``.
    assert isinstance(hf_device, str)
    return hf_device
|
||
|
|
|
||
|
|
def device_for_gptq(dev: DevIn = None) -> Union[int, str]:
    """Resolve *dev* via ``normalize_device(..., backend="gptq")``; the result is an int or a string."""
    gptq_device = normalize_device(dev, backend="gptq")
    # Sanity-check the normalized value against the declared return type.
    assert isinstance(gptq_device, (int, str))
    return gptq_device
|
||
|
|
|
||
|
|
def debug_device_placement(model, name="model"):
    """Print where a model's parameters are placed and return the first parameter's device.

    Args:
        model: Object exposing ``parameters()`` whose items carry a ``.device``
            attribute (presumably a ``torch.nn.Module`` -- confirm with callers).
        name: Label used in the printed ``[DEBUG]`` lines.

    Returns:
        The ``device`` of the first parameter, or ``None`` when the model has
        no parameters or the inspection fails for any reason.
    """
    try:
        # Materialize once so the device summary and the "first parameter"
        # lookup are computed from the same single pass.
        params = list(model.parameters())
        devices = {str(param.device) for param in params}
        print(f"[DEBUG] {name} parameters on devices: {devices}")

        if not params:
            # Calling next() on an empty iterator would raise StopIteration,
            # whose message is empty; report the real reason instead.
            print(f"[DEBUG] Could not check {name} device placement: model has no parameters")
            return None

        primary = params[0].device
        print(f"[DEBUG] {name} primary device: {primary}")
        return primary
    except Exception as e:
        # Best-effort debug helper: report the problem, never propagate it.
        print(f"[DEBUG] Could not check {name} device placement: {e}")
        return None
|
||
|
|
|
||
|
|
# --- Minimal self-test (run this file directly) ---------------------------------
|
||
|
|
if __name__ == "__main__":
    # Print how each sample spec resolves for the HF and GPTQ backends.
    for t in (None, "auto", "cpu", "mps", "disk", "cuda", "cuda:0", "cuda:1", 0, 1):
        print(f"in={t!r:7} -> hf={device_for_hf(t)!r:7} gptq={device_for_gptq(t)!r}")
|