From 76ffc7c4588c6ab5b3092c1284b9f527b103bd96 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 8 Dec 2025 23:52:02 +0000 Subject: [PATCH] [startup] auto-select GPU backend --- start.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/start.py b/start.py index 95bdd36..a05b736 100644 --- a/start.py +++ b/start.py @@ -47,17 +47,32 @@ def get_install_features(lib_name: str = None): possible_features = ["cu12", "amd"] if not lib_name: - # Ask the user for the GPU lib - gpu_lib_choices = { - "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"}, - "B": {"pretty": "AMD", "internal": "amd"}, - } - user_input = get_user_choice( - "Select your GPU. If you don't know, select Cuda 12.x (A)", - gpu_lib_choices, - ) + has_nvidia = which("nvidia-smi") is not None + has_rocm = which("rocm-smi") is not None + has_amd = which("amd-smi") is not None + has_amd_gpu = has_rocm or has_amd - lib_name = gpu_lib_choices.get(user_input, {}).get("internal") + if has_nvidia and not has_amd_gpu: + lib_name = "cu12" + print("Auto-detected NVIDIA GPU. Using CUDA 12.x backend.") + elif has_amd_gpu and not has_nvidia: + lib_name = "amd" + print("Auto-detected AMD GPU. Using AMD backend.") + else: + gpu_lib_choices = { + "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"}, + "B": {"pretty": "AMD", "internal": "amd"}, + } + print( + "WARNING: Auto-detection failed. " + "Please ensure you have either an NVIDIA GPU (with nvidia-smi) " + "or an AMD GPU (with rocm-smi or amd-smi) installed." + ) + user_input = get_user_choice( + "Select your GPU. If you don't know, select Cuda 12.x (A)", + gpu_lib_choices, + ) + lib_name = gpu_lib_choices.get(user_input, {}).get("internal") # Write to start options start_options["gpu_lib"] = lib_name