Model: Fix flash-attn checks
If flash attention is already turned off by exllamaV2 itself, don't try to create a paged generator. Also condense all of the redundant fallback logic into a single if statement, and apply arch_compat_overrides to check whether flash attention should be disabled for the model arch (e.g. Gemma 2).

Signed-off-by: kingbri <bdashore3@proton.me>
@@ -5,7 +5,6 @@ import gc
 import math
 import pathlib
 import traceback
-from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
 import torch
 import uuid
 from exllamav2 import (
@@ -28,6 +27,11 @@ from loguru import logger
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.utils import (
+    exllama_disabled_flash_attn,
+    hardware_supports_flash_attn,
+    supports_paged_attn,
+)
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
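The grouped import above brings in the new exllama_disabled_flash_attn helper next to the two existing probes. Its body isn't part of this diff; a minimal sketch of what such a helper could look like, assuming it simply reports (and logs) the flag that exl2 has already set, is:

from loguru import logger

def exllama_disabled_flash_attn(no_flash_attn: bool) -> bool:
    # Hypothetical sketch, not the actual utils implementation:
    # report whether ExLlamaV2 itself has already turned flash attention
    # off (e.g. an unsupported build, or an arch compat override).
    if no_flash_attn:
        logger.warning(
            "ExLlamaV2 has disabled flash attention; "
            "a paged generator will not be created."
        )
    return no_flash_attn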
@@ -173,6 +177,9 @@ class ExllamaV2Container:
 
         self.config.prepare()
 
+        # Check if the model arch is compatible with various exl2 features
+        self.config.arch_compat_overrides()
+
         # Then override the base_seq_len if present
         override_base_seq_len = kwargs.get("override_base_seq_len")
         if override_base_seq_len:
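This hunk is what makes the new check below meaningful: arch_compat_overrides() runs right after prepare(), so by the time the container reads self.config.no_flash_attn, exl2 has already flipped the flag for incompatible arches such as Gemma 2. A rough sketch of the resulting load order, with the container plumbing elided (model_directory is a hypothetical path):

from exllamav2 import ExLlamaV2Config

model_directory = "/models/gemma-2-9b-exl2"  # hypothetical path

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

# May set config.no_flash_attn for arches that can't use flash-attn
config.arch_compat_overrides()

# The condensed check in the next hunk then consumes the flag:
# exllama_disabled_flash_attn(config.no_flash_attn)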
@@ -200,13 +207,13 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Check whether the user's configuration supports paged attention
-        if not hardware_supports_flash_attn(gpu_device_list):
-            self.config.no_flash_attn = True
-            self.paged = False
-            self.max_batch_size = 1
-            torch.backends.cuda.enable_flash_sdp(False)
-        elif not supports_paged_attn():
+        # Check whether the user's configuration supports flash/paged attention
+        # Also check if exl2 has disabled flash attention
+        if (
+            exllama_disabled_flash_attn(self.config.no_flash_attn)
+            or not hardware_supports_flash_attn(gpu_device_list)
+            or not supports_paged_attn()
+        ):
             self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
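The condensed condition short-circuits left to right, so when exl2 has already disabled flash attention the hardware and paged-attention probes never run. A standalone sketch of the decision, with hypothetical stubs standing in for the real helpers:

def should_fall_back(
    exl2_disabled_flash_attn: bool,
    hardware_supports_flash_attn: bool,
    supports_paged_attn: bool,
) -> bool:
    # Any single failing condition disables flash attention, disables the
    # paged generator, and pins max_batch_size to 1.
    return (
        exl2_disabled_flash_attn
        or not hardware_supports_flash_attn
        or not supports_paged_attn
    )

assert should_fall_back(True, True, True)      # exl2 turned flash-attn off
assert should_fall_back(False, False, True)    # GPU can't run flash-attn
assert should_fall_back(False, True, False)    # no paged-attention support
assert not should_fall_back(False, True, True)  # all checks pass: paged mode

One behavioral note that falls out of the diff: the old hardware branch also called torch.backends.cuda.enable_flash_sdp(False), and the condensed version drops that call, so PyTorch's SDP backend selection is no longer touched here.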