diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 09da9a2..4745241 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -33,11 +33,7 @@ from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
     clear_grammar_func_cache,
 )
-from backends.exllamav2.utils import (
-    exllama_disabled_flash_attn,
-    hardware_supports_flash_attn,
-    supports_paged_attn,
-)
+from backends.exllamav2.utils import exllama_disabled_flash_attn
 from backends.exllamav2.vision import clear_image_embedding_cache
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
@@ -46,6 +42,7 @@ from common.gen_logging import (
     log_prompt,
     log_response,
 )
+from common.hardware import hardware_supports_flash_attn
 from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
@@ -278,11 +275,20 @@ class ExllamaV2Container(BaseModelContainer):
         # Check whether the user's configuration supports flash/paged attention
         # Also check if exl2 has disabled flash attention
-        if (
-            exllama_disabled_flash_attn(self.config.no_flash_attn)
-            or not hardware_supports_flash_attn(gpu_device_list)
-            or not supports_paged_attn()
-        ):
+        if exllama_disabled_flash_attn(
+            self.config.no_flash_attn
+        ) or not hardware_supports_flash_attn(gpu_device_list):
+            gpu_unsupported_message = (
+                "An unsupported GPU is found in this configuration. "
+                "Switching to compatibility mode. \n"
+                "This disables parallel batching "
+                "and features that rely on it (ex. CFG). \n"
+                "To disable compatibility mode, all GPUs must be ampere "
+                "(30 series) or newer. AMD GPUs are not supported."
+            )
+
+            logger.warning(gpu_unsupported_message)
+
             self.config.no_flash_attn = True

             if self.draft_config:
                 self.draft_config.no_flash_attn = True
diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py
index 0fd1fcc..1648c62 100644
--- a/backends/exllamav2/utils.py
+++ b/backends/exllamav2/utils.py
@@ -1,74 +1,6 @@
-import platform
-import torch
-from packaging import version
-from importlib.metadata import PackageNotFoundError, version as package_version
 from loguru import logger


-def hardware_supports_flash_attn(gpu_device_list: list[int]):
-    """
-    Check whether all GPUs in list support FA2
-
-    Compute capability < 8 is not supported by FA2
-    AMD is also unsupported until ROCm updates its FA2 fork
-    """
-
-    # Logged message if unsupported
-    unsupported_message = (
-        "An unsupported GPU is found in this configuration. "
-        "Switching to compatibility mode. \n"
-        "This disables parallel batching "
-        "and features that rely on it (ex. CFG). \n"
-        "To disable compatability mode, all GPUs must be ampere "
-        "(30 series) or newer. AMD GPUs are not supported."
-    )
-
-    min_compute_capability = min(
-        torch.cuda.get_device_capability(device=device_idx)[0]
-        for device_idx in gpu_device_list
-    )
-
-    if torch.version.hip or min_compute_capability < 8:
-        logger.warning(unsupported_message)
-        return False
-    else:
-        return True
-
-
-def supports_paged_attn():
-    """Check whether the user's flash-attn version supports paged mode"""
-
-    # Logged message if unsupported
-    unsupported_message = (
-        "Flash attention version >=2.5.7 "
-        "is required to use paged attention. "
-        "Switching to compatibility mode. \n"
-        "This disables parallel batching "
-        "and features that rely on it (ex. CFG). \n"
\n" - "Please upgrade your environment by running an update script " - "(update_scripts/" - f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "NOTE: Windows users must use CUDA 12.x to use flash-attn." - ) - - required_version = version.parse("2.5.7") - try: - current_version = version.parse(package_version("flash-attn").split("+")[0]) - except PackageNotFoundError: - logger.warning(unsupported_message) - return False - - if current_version < required_version: - logger.warning(unsupported_message) - return False - else: - return True - - def exllama_disabled_flash_attn(no_flash_attn: bool): unsupported_message = ( "ExllamaV2 has disabled Flash Attention. \n" diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 06d3b29..ce27a85 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -27,13 +27,14 @@ from common.gen_logging import ( log_generation_params, log_metrics, ) +from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import GenerationConfig from common.utils import coalesce, unwrap -from endpoints.core.types.model import ModelCard +from endpoints.core.types.model import ModelCard, ModelCardParameters class ExllamaV3Container(BaseModelContainer): @@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer): tokenizer: Tokenizer config: Config generator: Optional[AsyncGenerator] = None + + # Class-specific vars gpu_split: List[float] | None = None gpu_split_auto: bool = True autosplit_reserve: List[float] = [96 / 1024] - max_seq_len: int use_tp: bool = False + max_seq_len: int = 4096 + cache_size: int = 4096 + chunk_size: int = 2048 + max_batch_size: Optional[int] = None # Required methods @classmethod @@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer): self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) - self.max_seq_len = kwargs.get("max_seq_len") - self.cache = Cache(self.model, max_num_tokens=self.max_seq_len) + # Fallback to 4096 since exl3 can't fetch from HF's config.json + self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) # Try to set prompt template self.prompt_template = await find_prompt_template( @@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer): gpu_count = torch.cuda.device_count() gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) gpu_split = unwrap(kwargs.get("gpu_split"), None) + gpu_device_list = list(range(0, gpu_count)) # Set GPU split options if gpu_count == 1: @@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer): # Enable manual GPU split if provided if gpu_split: self.gpu_split = gpu_split + + gpu_device_list = [ + device_idx + for device_idx, memory in enumerate(self.gpu_split) + if memory > 0 + ] elif gpu_split_auto and not self.use_tp: # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto @@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer): self.autosplit_reserve = [ value / 1024 for value in autosplit_reserve_megabytes ] + + if not hardware_supports_flash_attn(gpu_device_list): + gpu_unsupported_message = ( + "Unable to run ExllamaV3 because an unsupported GPU 
is " + "found in this configuration. \n" + "All GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + + logger.warning(gpu_unsupported_message) + + raise RuntimeError(gpu_unsupported_message) + + # Cache + user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len) + self.cache_size = self.adjust_cache_size(user_cache_size) + self.cache = Cache(self.model, max_num_tokens=self.cache_size) + + # Max batch size + self.max_batch_size = kwargs.get("max_batch_size") + + # Make sure chunk size is >= 256, keep near or below max seq len + user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) + self.chunk_size = self.adjust_chunk_size(user_chunk_size) + # TODO: speculative decoding return self + def adjust_cache_size(self, cache_size): + if cache_size < self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is smaller than the " + "desired context length.\n" + "Overriding cache_size to max_seq_len. " + ) + + cache_size = self.max_seq_len + + # Enforce a multiple of 256 for cache size + # Overestimate to ensure that the cache isn't below max_seq_len + cache_remainder = cache_size % 256 + if cache_remainder != 0: + rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1)) + + logger.warning( + f"The given cache size ({cache_size}) is " + "not a multiple of 256.\n" + "Overriding cache_size with an overestimated value of " + f"{rounded_cache_size} tokens." + ) + + cache_size = rounded_cache_size + + # Warn user if cache size may be inadequate for CFG + if cache_size < 2 * self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is less than 2 * max_seq_len " + "and may be too small for requests using CFG. \n" + "Ignore this warning if you do not plan on using CFG." + ) + + return cache_size + + def adjust_chunk_size(self, user_chunk_size: int): + chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1] + chunk_remainder = chunk_size % 256 + if chunk_remainder != 0: + rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1)) + + logger.warning( + f"The given chunk size ({chunk_size}) is " + "not a multiple of 256.\n" + "Overriding chunk_size with an overestimated value of " + f"{rounded_chunk_size} tokens." + ) + + chunk_size = rounded_chunk_size + + return chunk_size + def model_info(self) -> ModelCard: """ Returns a dictionary of the current model's configuration parameters. @@ -138,7 +228,25 @@ class ExllamaV3Container(BaseModelContainer): Model parameters provided by the backend """ - pass + model_params = ModelCardParameters( + max_seq_len=self.max_seq_len, + cache_size=self.cache_size, + max_batch_size=self.max_batch_size, + # cache_mode=self.cache_mode, + chunk_size=self.chunk_size, + use_vision=self.use_vision, + ) + + if self.prompt_template: + model_params.prompt_template = self.prompt_template.name + model_params.prompt_template_content = self.prompt_template.raw_template + + model_card = ModelCard( + id=self.model_dir.name, + parameters=model_params, + ) + + return model_card async def wait_for_jobs(self, skip_wait: bool = False): """ @@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer): cache=self.cache, tokenizer=self.tokenizer, max_batch_size=self.max_batch_size, + max_chunk_size=self.chunk_size, ) # Update the state of the container var