From 72b08624a39e097d5aaeee549cc27ceaf80c40c3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 20 Mar 2024 00:46:30 -0400 Subject: [PATCH] Start: Update to use pyproject Signed-off-by: kingbri --- start.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/start.py b/start.py index 82655c7..458fe88 100644 --- a/start.py +++ b/start.py @@ -8,15 +8,15 @@ import subprocess from common.args import convert_args_to_dict, init_argparser -def get_requirements_file(): +def get_install_features(): """Fetches the appropriate requirements file depending on the GPU""" - requirements_name = "requirements-nowheel" + install_features = None ROCM_PATH = os.environ.get("ROCM_PATH") CUDA_PATH = os.environ.get("CUDA_PATH") # TODO: Check if the user has an AMD gpu on windows if ROCM_PATH: - requirements_name = "requirements-amd" + install_features = "amd" # Also override env vars for ROCm support on non-supported GPUs os.environ["ROCM_PATH"] = "/opt/rocm" @@ -25,11 +25,11 @@ def get_requirements_file(): elif CUDA_PATH: cuda_version = pathlib.Path(CUDA_PATH).name if "12" in cuda_version: - requirements_name = "requirements" + install_features = "cu121" elif "11" in cuda_version: - requirements_name = "requirements-cu118" + install_features = "cu118" - return requirements_name + return install_features def add_start_args(parser: argparse.ArgumentParser): @@ -60,10 +60,11 @@ if __name__ == "__main__": if args.ignore_upgrade: print("Ignoring pip dependency upgrade due to user request.") else: - requirements_file = ( - "requirements-nowheel" if args.nowheel else get_requirements_file() - ) - subprocess.run(["pip", "install", "-U", "-r", f"{requirements_file}.txt"]) + install_features = None if args.nowheel else get_install_features() + features = f"[{install_features}]" if install_features else "" + + # pip install .[features] + subprocess.run(["pip", "install", "-U", f".{features}"]) # Import entrypoint after installing all requirements from main import entrypoint