diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 22f9239..d4a65c2 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -210,8 +210,10 @@ class ExllamaV2Container:
                 "To disable compatability mode, all GPUs must be ampere "
                 "(30 series) or newer. AMD GPUs are not supported."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
         elif not supports_paged_attn():
             logger.warning(
                 "Flash attention version >=2.5.7 "
@@ -229,8 +231,10 @@ class ExllamaV2Container:
                 "pip install --upgrade .[cu118]\n\n"
                 "NOTE: Windows users must use CUDA 12.x to use flash-attn."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
 
         # Set k/v cache size
         # cache_size is only relevant when paged mode is enabled
@@ -331,6 +335,7 @@ class ExllamaV2Container:
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
+            self.draft_config.no_flash_attn = self.config.no_flash_attn
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
             )
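For context, here is a minimal standalone sketch (not part of the patch) of what the added `torch.backends.cuda.enable_flash_sdp(False)` calls do: they disable the flash kernel behind PyTorch's `scaled_dot_product_attention` process-wide, so any SDPA call falls back to the mem-efficient or math backends. The tensor shapes below are hypothetical and purely illustrative.

```python
import torch
import torch.nn.functional as F

# Mirror the patch: globally turn off the flash backend for SDPA.
torch.backends.cuda.enable_flash_sdp(False)

# Flash is now disabled; the mem-efficient and math backends remain available.
print(torch.backends.cuda.flash_sdp_enabled())          # False
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True

if torch.cuda.is_available():
    # Hypothetical shapes (batch, heads, seq_len, head_dim) for illustration.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    # Dispatches to a non-flash kernel now that flash-sdp is off.
    out = F.scaled_dot_product_attention(q, k, v)
```

The patch sets both flags for the same fallback: `config.no_flash_attn = True` keeps exllamav2 itself off flash-attn, while the torch backend flag covers any SDPA path outside exllamav2's control, and the new draft-config line propagates the setting so a speculative-decoding draft model matches the main model's attention mode.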