From 156b74f3f0c98e4851e85c0b7aa80e62c814ddf6 Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Sun, 9 Jun 2024 08:28:11 -0700
Subject: [PATCH] Revision to paged attention checks (#133)

* Model: Clean up paged attention checks

* Model: Move cache_size checks after paged attn checks

Cache size is only relevant in paged mode

* Model: Fix no_flash_attention

* Model: Remove no_flash_attention

Ability to use flash attention is auto-detected, so this flag is unneeded.
Uninstall flash attention to disable it on supported hardware.
---
 backends/exllamav2/model.py  | 156 ++++++++++++++---------------------
 backends/exllamav2/utils.py  |  36 +++++++-
 endpoints/OAI/types/model.py |   1 -
 3 files changed, 99 insertions(+), 94 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 798ef2f..22f9239 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -5,6 +5,7 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
 import torch
 import uuid
 from exllamav2 import (
@@ -196,112 +197,83 @@ class ExllamaV2Container:
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )

-        # Set k/v cache size
-        cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
-
-        if cache_size < self.config.max_seq_len:
-            logger.warning(
-                f"The given cache_size ({cache_size}) is smaller than the "
-                "desired context length.\n"
-                "Overriding cache_size to max_seq_len. "
-            )
-
-            cache_size = self.config.max_seq_len
-
-        # Enforce a multiple of 256 for cache size
-        # Overestimate to ensure that the cache isn't below max_seq_len
-        cache_remainder = cache_size % 256
-        if cache_remainder != 0:
-            rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
-
-            logger.warning(
-                f"The given cache size ({cache_size}) is "
-                "not a multiple of 256.\n"
-                "Overriding cache_size with an overestimated value of "
-                f"{rounded_cache_size} tokens."
-            )
-
-            cache_size = rounded_cache_size
-
-        # Warn user if cache size may be inadequate for CFG
-        if cache_size < 2 * self.config.max_seq_len:
-            logger.warning(
-                f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
-                "and may be too small for requests using CFG. \n"
-                "Ignore this warning if you do not plan on using CFG."
-            )
-
-        self.cache_size = cache_size
-
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

         # Check whether the user's configuration supports paged attention
-        if self.config.no_flash_attn:
+        if not hardware_supports_flash_attn(gpu_device_list):
             logger.warning(
-                "Flash attention is disabled via config. "
+                "An unsupported GPU is found in this configuration. "
                 "Switching to compatibility mode. \n"
                 "This disables parallel batching "
-                "and features that rely on it (ex. CFG)."
+                "and features that rely on it (ex. CFG). \n"
+                "To disable compatability mode, all GPUs must be ampere "
+                "(30 series) or newer. AMD GPUs are not supported."
+            )
+            self.paged = False
+            self.max_batch_size = 1
+        elif not supports_paged_attn():
+            logger.warning(
+                "Flash attention version >=2.5.7 "
+                "is required to use paged attention. "
+                "Switching to compatibility mode. \n"
+                "This disables parallel batching "
+                "and features that rely on it (ex. CFG). \n"
+                "Please upgrade your environment by running a start script "
+                "(start.bat or start.sh)\n\n"
+                "Or you can manually run a requirements update "
+                "using the following command:\n\n"
+                "For CUDA 12.1:\n"
+                "pip install --upgrade .[cu121]\n\n"
+                "For CUDA 11.8:\n"
+                "pip install --upgrade .[cu118]\n\n"
+                "NOTE: Windows users must use CUDA 12.x to use flash-attn."
             )
             self.paged = False
             self.max_batch_size = 1
-        else:
-            try:
-                # Disable paged mode if the user's min GPU isn't supported (ampere+)
-                min_compute_capability = min(
-                    torch.cuda.get_device_capability(device=device_idx)[0]
-                    for device_idx in gpu_device_list
-                )

-                # Compute capability < 8 is not supported by FA2
-                # AMD is also unsupported until ROCm updates its FA2 fork
-                if torch.version.hip or min_compute_capability < 8:
-                    logger.warning(
-                        "An unsupported GPU is found in this configuration. "
-                        "Switching to compatibility mode. \n"
-                        "This disables parallel batching "
-                        "and features that rely on it (ex. CFG). \n"
-                        "To disable compatability mode, all GPUs must be ampere "
-                        "(30 series) or newer. AMD GPUs are not supported."
-                    )
-                    self.config.no_flash_attn = True
-                    self.paged = False
-                    self.max_batch_size = 1
-                else:
-                    import flash_attn
+        # Set k/v cache size
+        # cache_size is only relevant when paged mode is enabled
+        if self.paged:
+            cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)

-                    flash_attn_ver = [
-                        int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
-                    ]
-
-                    # Disable paged mode if the user's flash attention version < 2.5.7
-                    if flash_attn_ver < [2, 5, 7]:
-                        logger.warning(
-                            "Flash attention version is older than 2.5.7 "
-                            "which is required for paged attention. "
-                            "Switching to compatibility mode. \n"
-                            "This disables parallel batching "
-                            "and features that rely on it (ex. CFG). \n"
-                            "Please run start.bat or start.sh to update. \n"
-                            "NOTE: Windows users must select CUDA 12.x to use FA2."
-                        )
-                        self.paged = False
-                        self.max_batch_size = 1
-
-            except ModuleNotFoundError:
-                # Disable paged mode if flash attention is not installed
+            if cache_size < self.config.max_seq_len:
                 logger.warning(
-                    "Flash attention is not installed. "
-                    "Switching to compatibility mode. \n"
-                    "This disables parallel batching "
-                    "and features that rely on it (ex. CFG)."
-                    "Please run start.bat or start.sh to install. \n"
-                    "NOTE: Windows users must select CUDA 12.x to use FA2."
+                    f"The given cache_size ({cache_size}) is smaller than the "
+                    "desired context length.\n"
+                    "Overriding cache_size to max_seq_len. "
                 )
-                self.config.no_flash_attn = True
-                self.paged = False
-                self.max_batch_size = 1
+
+                cache_size = self.config.max_seq_len
+
+            # Enforce a multiple of 256 for cache size
+            # Overestimate to ensure that the cache isn't below max_seq_len
+            cache_remainder = cache_size % 256
+            if cache_remainder != 0:
+                rounded_cache_size = int(
+                    256 * ((cache_size - cache_remainder) / 256 + 1)
+                )
+
+                logger.warning(
+                    f"The given cache size ({cache_size}) is "
+                    "not a multiple of 256.\n"
+                    "Overriding cache_size with an overestimated value of "
+                    f"{rounded_cache_size} tokens."
+                )
+
+                cache_size = rounded_cache_size
+
+            # Warn user if cache size may be inadequate for CFG
+            if cache_size < 2 * self.config.max_seq_len:
+                logger.warning(
+                    f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
+                    "and may be too small for requests using CFG. \n"
+                    "Ignore this warning if you do not plan on using CFG."
+ ) + + self.cache_size = cache_size + else: + self.cache_size = self.config.max_seq_len # Try to set prompt template self.prompt_template = self.find_prompt_template( diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 0c0e2cf..584165a 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,6 +1,7 @@ from packaging import version -from importlib.metadata import version as package_version +from importlib.metadata import PackageNotFoundError, version as package_version from loguru import logger +import torch def check_exllama_version(): @@ -26,3 +27,36 @@ def check_exllama_version(): ) else: logger.info(f"ExllamaV2 version: {current_version}") + + +def hardware_supports_flash_attn(gpu_device_list: list[int]): + """ + Check whether all GPUs in list support FA2 + + Compute capability < 8 is not supported by FA2 + AMD is also unsupported until ROCm updates its FA2 fork + """ + + min_compute_capability = min( + torch.cuda.get_device_capability(device=device_idx)[0] + for device_idx in gpu_device_list + ) + if torch.version.hip or min_compute_capability < 8: + return False + else: + return True + + +def supports_paged_attn(): + """Check whether the user's flash-attn version supports paged mode""" + + required_version = version.parse("2.5.7") + try: + current_version = version.parse(package_version("flash-attn").split("+")[0]) + except PackageNotFoundError: + return False + + if current_version < required_version: + return False + else: + return True diff --git a/endpoints/OAI/types/model.py b/endpoints/OAI/types/model.py index 8847680..c549b49 100644 --- a/endpoints/OAI/types/model.py +++ b/endpoints/OAI/types/model.py @@ -94,7 +94,6 @@ class ModelLoadRequest(BaseModel): default=None, examples=[1.0], ) - no_flash_attention: Optional[bool] = False # low_mem: Optional[bool] = False cache_mode: Optional[str] = "FP16" chunk_size: Optional[int] = 2048
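Note for readers skimming the patch: after this change, model.py enables paged attention only when both new helpers in backends/exllamav2/utils.py pass, and the k/v cache sizing (clamp to max_seq_len, then round up to a multiple of 256) only runs in paged mode. The sketch below restates that flow as standalone Python. It is illustrative only: resolve_paged_mode and round_cache_size are hypothetical names invented for this note, and the single-GPU device list in the example is an assumption, not something the patch defines.

    # Sketch only: mirrors the checks and rounding in this patch; not part of the commit
    from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn


    def resolve_paged_mode(gpu_device_list: list[int]) -> bool:
        """Paged attention stays enabled only if both helper checks pass."""
        if not hardware_supports_flash_attn(gpu_device_list):
            # A pre-Ampere or AMD GPU is present: compatibility mode, max_batch_size = 1
            return False
        if not supports_paged_attn():
            # flash-attn is missing or older than 2.5.7: compatibility mode
            return False
        return True


    def round_cache_size(cache_size: int, max_seq_len: int) -> int:
        """Clamp the cache to max_seq_len, then round up to the next multiple of 256."""
        cache_size = max(cache_size, max_seq_len)
        remainder = cache_size % 256
        if remainder != 0:
            cache_size = int(256 * ((cache_size - remainder) / 256 + 1))
        return cache_size


    # Example: one GPU (device 0), a requested 10000-token cache, max_seq_len of 8192
    if resolve_paged_mode([0]):
        print(round_cache_size(10000, 8192))  # 10240

In compatibility mode the container skips the cache_size math entirely and pins the cache to max_seq_len, which is exactly what the new else branch at the end of the model.py hunk does.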