[startup] auto-select GPU backend

This commit is contained in:
AlpinDale
2025-12-08 23:52:02 +00:00
parent 8b6b793bfc
commit 76ffc7c458

View File

@@ -47,17 +47,32 @@ def get_install_features(lib_name: str = None):
possible_features = ["cu12", "amd"]
if not lib_name:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"},
"B": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)
has_nvidia = which("nvidia-smi") is not None
has_rocm = which("rocm-smi") is not None
has_amd = which("amd-smi") is not None
has_amd_gpu = has_rocm or has_amd
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
if has_nvidia and not has_amd_gpu:
lib_name = "cu12"
print("Auto-detected NVIDIA GPU. Using CUDA 12.x backend.")
elif has_amd_gpu and not has_nvidia:
lib_name = "amd"
print("Auto-detected AMD GPU. Using AMD backend.")
else:
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"},
"B": {"pretty": "AMD", "internal": "amd"},
}
print(
"WARNING: Auto-detection failed. "
"Please ensure you have either an NVIDIA GPU (with nvidia-smi) "
"or an AMD GPU (with rocm-smi or amd-smi) installed."
)
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
# Write to start options
start_options["gpu_lib"] = lib_name