From e391d84e40b13763860f24a945b8d1a3d4dc266d Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Wed, 5 Jun 2024 00:33:21 -0700 Subject: [PATCH] More extensive checks for paged mode support (#121) * Model: More extensive checks for paged attention Previously, TabbyAPI only checked for whether the user's hardware supports flash attention before deciding whether to enable paged mode. This adds checks for whether no_flash_attention is set, whether flash-attn is installed, and whether the installed version supports paged attention. * Tree: Format * Tree: Lint * Model: Check GPU architecture first Check GPU arch prior to checking whether flash attention 2 is installed --- backends/exllamav2/model.py | 74 ++++++++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 31e0643..233e054 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -224,26 +224,72 @@ class ExllamaV2Container: # Enable fasttensors loading if present self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False) - # Disable paged mode if the user's min GPU isn't supported (ampere and up) - min_compute_capability = min( - torch.cuda.get_device_capability(device=device_idx)[0] - for device_idx in gpu_device_list - ) - - # Compute capability < 8 is not supported by FA2 - # AMD is also unsupported until ROCm updates its FA2 fork - if torch.version.hip or min_compute_capability < 8: + # Check whether the user's configuration supports paged attention + if self.config.no_flash_attn: logger.warning( - "An unsupported GPU is found in this configuration. " + "Flash attention is disabled via config. " "Switching to compatibility mode. \n" "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "To disable compatability mode, all GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." 
+ "and features that rely on it (ex. CFG)." ) - self.config.no_flash_attn = True self.paged = False self.max_batch_size = 1 + else: + try: + # Disable paged mode if the user's min GPU isn't supported (ampere+) + min_compute_capability = min( + torch.cuda.get_device_capability(device=device_idx)[0] + for device_idx in gpu_device_list + ) + + # Compute capability < 8 is not supported by FA2 + # AMD is also unsupported until ROCm updates its FA2 fork + if torch.version.hip or min_compute_capability < 8: + logger.warning( + "An unsupported GPU is found in this configuration. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "To disable compatability mode, all GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + self.config.no_flash_attn = True + self.paged = False + self.max_batch_size = 1 + else: + import flash_attn + + flash_attn_ver = [ + int(t) for t in flash_attn.__version__.split(".") if t.isdigit() + ] + + # Disable paged mode if the user's flash attention version < 2.5.7 + if flash_attn_ver < [2, 5, 7]: + logger.warning( + "Flash attention version is older than 2.5.7 " + "which is required for paged attention. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "Please run start.bat or start.sh to update. \n" + "NOTE: Windows users must select CUDA 12.x to use FA2." + ) + self.paged = False + self.max_batch_size = 1 + + except ModuleNotFoundError: + # Disable paged mode if flash attention is not installed + logger.warning( + "Flash attention is not installed. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "Please run start.bat or start.sh to install. \n" + "NOTE: Windows users must select CUDA 12.x to use FA2." 
+ ) + self.config.no_flash_attn = True + self.paged = False + self.max_batch_size = 1 # Try to set prompt template self.prompt_template = self.find_prompt_template(