mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
More extensive checks for paged mode support (#121)
* Model: More extensive checks for paged attention Previously, TabbyAPI only checked for whether the user's hardware supports flash attention before deciding whether to enabled paged mode. This adds checks for whether no_flash_attention is set, whether flash-attn is installed, and whether the installed version supports paged attention. * Tree: Format * Tree: Lint * Model: Check GPU architecture first Check GPU arch prior to checking whether flash attention 2 is installed
This commit is contained in:
@@ -224,26 +224,72 @@ class ExllamaV2Container:
|
||||
# Enable fasttensors loading if present
|
||||
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
|
||||
|
||||
# Disable paged mode if the user's min GPU isn't supported (ampere and up)
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
# Compute capability < 8 is not supported by FA2
|
||||
# AMD is also unsupported until ROCm updates its FA2 fork
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
# Check whether the user's configuration supports paged attention
|
||||
if self.config.no_flash_attn:
|
||||
logger.warning(
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Flash attention is disabled via config. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
"and features that rely on it (ex. CFG)."
|
||||
)
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
else:
|
||||
try:
|
||||
# Disable paged mode if the user's min GPU isn't supported (ampere+)
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
# Compute capability < 8 is not supported by FA2
|
||||
# AMD is also unsupported until ROCm updates its FA2 fork
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
else:
|
||||
import flash_attn
|
||||
|
||||
flash_attn_ver = [
|
||||
int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
|
||||
]
|
||||
|
||||
# Disable paged mode if the user's flash attention version < 2.5.7
|
||||
if flash_attn_ver < [2, 5, 7]:
|
||||
logger.warning(
|
||||
"Flash attention version is older than 2.5.7 "
|
||||
"which is required for paged attention. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"Please run start.bat or start.sh to update. \n"
|
||||
"NOTE: Windows users must select CUDA 12.x to use FA2."
|
||||
)
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
||||
except ModuleNotFoundError:
|
||||
# Disable paged mode if flash attention is not installed
|
||||
logger.warning(
|
||||
"Flash attention is not installed. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG)."
|
||||
"Please run start.bat or start.sh to install. \n"
|
||||
"NOTE: Windows users must select CUDA 12.x to use FA2."
|
||||
)
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
|
||||
Reference in New Issue
Block a user