From c575105e41940b26fcf39801f7bdf0e3370717cc Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 16 Jun 2024 00:15:05 -0400 Subject: [PATCH] ExllamaV2: Cleanup log placements Move the large import errors into the check functions themselves. This helps reduce the difficulty in interpreting where errors are coming from. Signed-off-by: kingbri --- backends/exllamav2/model.py | 24 -------------- backends/exllamav2/utils.py | 62 ++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 19854d0..b28cfd8 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -202,35 +202,11 @@ class ExllamaV2Container: # Check whether the user's configuration supports paged attention if not hardware_supports_flash_attn(gpu_device_list): - logger.warning( - "An unsupported GPU is found in this configuration. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "To disable compatability mode, all GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." - ) self.config.no_flash_attn = True self.paged = False self.max_batch_size = 1 torch.backends.cuda.enable_flash_sdp(False) elif not supports_paged_attn(): - logger.warning( - "Flash attention version >=2.5.7 " - "is required to use paged attention. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "For CUDA 11.8:\n" - "pip install --upgrade .[cu118]\n\n" - "NOTE: Windows users must use CUDA 12.x to use flash-attn." - ) self.config.no_flash_attn = True self.paged = False self.max_batch_size = 1 diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 584165a..0417819 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -10,21 +10,23 @@ def check_exllama_version(): required_version = version.parse("0.1.5") current_version = version.parse(package_version("exllamav2").split("+")[0]) + unsupported_message = ( + f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " + f"or greater. Your current version is {current_version}.\n" + "Please upgrade your environment by running a start script " + "(start.bat or start.sh)\n\n" + "Or you can manually run a requirements update " + "using the following command:\n\n" + "For CUDA 12.1:\n" + "pip install --upgrade .[cu121]\n\n" + "For CUDA 11.8:\n" + "pip install --upgrade .[cu118]\n\n" + "For ROCm:\n" + "pip install --upgrade .[amd]\n\n" + ) + if current_version < required_version: - raise SystemExit( - f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " - f"or greater. Your current version is {current_version}.\n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "For CUDA 11.8:\n" - "pip install --upgrade .[cu118]\n\n" - "For ROCm:\n" - "pip install --upgrade .[amd]\n\n" - ) + raise SystemExit(unsupported_message) else: logger.info(f"ExllamaV2 version: {current_version}") @@ -37,11 +39,23 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]): AMD is also unsupported until ROCm updates its FA2 fork """ + # Logged message if unsupported + unsupported_message = ( + "An unsupported GPU is found in this configuration. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "To disable compatability mode, all GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + min_compute_capability = min( torch.cuda.get_device_capability(device=device_idx)[0] for device_idx in gpu_device_list ) + if torch.version.hip or min_compute_capability < 8: + logger.warning(unsupported_message) return False else: return True @@ -50,13 +64,33 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]): def supports_paged_attn(): """Check whether the user's flash-attn version supports paged mode""" + # Logged message if unsupported + unsupported_message = ( + "Flash attention version >=2.5.7 " + "is required to use paged attention. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "Please upgrade your environment by running a start script " + "(start.bat or start.sh)\n\n" + "Or you can manually run a requirements update " + "using the following command:\n\n" + "For CUDA 12.1:\n" + "pip install --upgrade .[cu121]\n\n" + "For CUDA 11.8:\n" + "pip install --upgrade .[cu118]\n\n" + "NOTE: Windows users must use CUDA 12.x to use flash-attn." + ) + required_version = version.parse("2.5.7") try: current_version = version.parse(package_version("flash-attn").split("+")[0]) except PackageNotFoundError: + logger.warning(unsupported_message) return False if current_version < required_version: + logger.warning(unsupported_message) return False else: return True