diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 31e0643..233e054 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -224,26 +224,72 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Disable paged mode if the user's min GPU isn't supported (ampere and up)
-        min_compute_capability = min(
-            torch.cuda.get_device_capability(device=device_idx)[0]
-            for device_idx in gpu_device_list
-        )
-
-        # Compute capability < 8 is not supported by FA2
-        # AMD is also unsupported until ROCm updates its FA2 fork
-        if torch.version.hip or min_compute_capability < 8:
+        # Check whether the user's configuration supports paged attention
+        if self.config.no_flash_attn:
             logger.warning(
-                "An unsupported GPU is found in this configuration. "
+                "Flash attention is disabled via config. "
                 "Switching to compatibility mode. \n"
                 "This disables parallel batching "
-                "and features that rely on it (ex. CFG). \n"
-                "To disable compatability mode, all GPUs must be ampere "
-                "(30 series) or newer. AMD GPUs are not supported."
+                "and features that rely on it (ex. CFG)."
             )
-            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+        else:
+            try:
+                # Disable paged mode if the user's min GPU isn't supported (ampere+)
+                min_compute_capability = min(
+                    torch.cuda.get_device_capability(device=device_idx)[0]
+                    for device_idx in gpu_device_list
+                )
+
+                # Compute capability < 8 is not supported by FA2
+                # AMD is also unsupported until ROCm updates its FA2 fork
+                if torch.version.hip or min_compute_capability < 8:
+                    logger.warning(
+                        "An unsupported GPU is found in this configuration. "
+                        "Switching to compatibility mode. \n"
+                        "This disables parallel batching "
+                        "and features that rely on it (ex. CFG). \n"
+                        "To disable compatibility mode, all GPUs must be ampere "
+                        "(30 series) or newer. AMD GPUs are not supported."
+                    )
+                    self.config.no_flash_attn = True
+                    self.paged = False
+                    self.max_batch_size = 1
+                else:
+                    import flash_attn
+
+                    flash_attn_ver = [
+                        int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
+                    ]
+
+                    # Disable paged mode if the user's flash attention version < 2.5.7
+                    if flash_attn_ver < [2, 5, 7]:
+                        logger.warning(
+                            "Flash attention version is older than 2.5.7 "
+                            "which is required for paged attention. "
+                            "Switching to compatibility mode. \n"
+                            "This disables parallel batching "
+                            "and features that rely on it (ex. CFG). \n"
+                            "Please run start.bat or start.sh to update. \n"
+                            "NOTE: Windows users must select CUDA 12.x to use FA2."
+                        )
+                        self.paged = False
+                        self.max_batch_size = 1
+
+            except ModuleNotFoundError:
+                # Disable paged mode if flash attention is not installed
+                logger.warning(
+                    "Flash attention is not installed. "
+                    "Switching to compatibility mode. \n"
+                    "This disables parallel batching "
+                    "and features that rely on it (ex. CFG). \n"
+                    "Please run start.bat or start.sh to install. \n"
+                    "NOTE: Windows users must select CUDA 12.x to use FA2."
+                )
+                self.config.no_flash_attn = True
+                self.paged = False
+                self.max_batch_size = 1
 
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
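
For context, a minimal standalone sketch of the version gate this hunk introduces (a hypothetical script, not part of the patch; the paged_supported name is illustrative, and it assumes flash_attn exposes __version__ as the patch relies on):

    # Python list comparison is element-wise and lexicographic, so splitting the
    # version string into integer parts makes the ">= 2.5.7" check straightforward.
    try:
        import flash_attn

        flash_attn_ver = [
            int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
        ]
        paged_supported = flash_attn_ver >= [2, 5, 7]
    except ModuleNotFoundError:
        # Mirrors the patch's fallback: without flash-attn, paged attention is off
        paged_supported = False

    print(f"paged attention supported: {paged_supported}")

For example, [2, 5, 6] < [2, 5, 7] while [2, 6, 0] >= [2, 5, 7], which is the behavior the patch needs.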