Mirror of https://github.com/theroyallab/tabbyAPI.git
Model: Fix flash-attn checks
If flash attention has already been turned off by exllamaV2 itself, don't try creating a paged generator. Condense the redundant fallback logic into one if statement, and check arch_compat_overrides to see whether flash attention should be disabled for a model arch (ex. Gemma 2).

Signed-off-by: kingbri <bdashore3@proton.me>
backends/exllamav2/model.py

@@ -5,7 +5,6 @@ import gc
 import math
 import pathlib
 import traceback
-from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
 import torch
 import uuid
 from exllamav2 import (
@@ -28,6 +27,11 @@ from loguru import logger
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.utils import (
+    exllama_disabled_flash_attn,
+    hardware_supports_flash_attn,
+    supports_paged_attn,
+)
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -173,6 +177,9 @@ class ExllamaV2Container:
 
         self.config.prepare()
 
+        # Check if the model arch is compatible with various exl2 features
+        self.config.arch_compat_overrides()
+
         # Then override the base_seq_len if present
         override_base_seq_len = kwargs.get("override_base_seq_len")
         if override_base_seq_len:
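Running arch_compat_overrides() right after prepare() means per-arch restrictions (ex. Gemma 2 disabling flash attention, per the commit message) are already reflected in the config before the attention checks below execute, so the consolidated guard in the next hunk, which reads config.no_flash_attn, can pick them up.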
@@ -200,13 +207,13 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Check whether the user's configuration supports paged attention
-        if not hardware_supports_flash_attn(gpu_device_list):
-            self.config.no_flash_attn = True
-            self.paged = False
-            self.max_batch_size = 1
-            torch.backends.cuda.enable_flash_sdp(False)
-        elif not supports_paged_attn():
+        # Check whether the user's configuration supports flash/paged attention
+        # Also check if exl2 has disabled flash attention
+        if (
+            exllama_disabled_flash_attn(self.config.no_flash_attn)
+            or not hardware_supports_flash_attn(gpu_device_list)
+            or not supports_paged_attn()
+        ):
             self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
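Since every failing check funnels into the same compatibility settings, the old if/elif pair collapses into a single guard. A minimal standalone sketch of the pattern (the checks are stubbed as plain booleans; pick_attention_mode and its parameters are hypothetical stand-ins for illustration, not names from this commit):

    # Any failing check forces compatibility mode: no flash attention,
    # no paged generator, and a max batch size of 1.
    def pick_attention_mode(exl2_disabled: bool, hw_ok: bool, paged_ok: bool) -> dict:
        if exl2_disabled or not hw_ok or not paged_ok:
            return {"no_flash_attn": True, "paged": False, "max_batch_size": 1}
        # All checks passed: paged generation with flash attention stays enabled.
        return {"no_flash_attn": False, "paged": True, "max_batch_size": None}

    # Example: exl2 itself turned flash attention off, so we fall back.
    print(pick_attention_mode(exl2_disabled=True, hw_ok=True, paged_ok=True))
    # {'no_flash_attn': True, 'paged': False, 'max_batch_size': 1}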
backends/exllamav2/utils.py

@@ -94,3 +94,18 @@ def supports_paged_attn():
         return False
     else:
         return True
+
+
+def exllama_disabled_flash_attn(no_flash_attn: bool):
+    unsupported_message = (
+        "ExllamaV2 has disabled Flash Attention. \n"
+        "Please see the above logs for warnings/errors. \n"
+        "Switching to compatibility mode. \n"
+        "This disables parallel batching "
+        "and features that rely on it (ex. CFG). \n"
+    )
+
+    if no_flash_attn:
+        logger.warning(unsupported_message)
+
+    return no_flash_attn
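Note the warn-and-return design: the helper logs the compatibility-mode warning as a side effect and passes the flag straight through, which is what lets it sit first in the boolean expression in model.py without needing a separate branch.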