More extensive checks for paged mode support (#121)

* Model: More extensive checks for paged attention
Previously, TabbyAPI only checked whether the user's hardware supports flash attention before deciding whether to enable paged mode.
This adds checks for whether no_flash_attention is set, whether flash-attn is installed, and whether the installed version supports paged attention (see the sketch after this list).

* Tree: Format

* Tree: Lint

* Model: Check GPU architecture first
Check GPU arch prior to checking whether flash attention 2 is installed
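Taken together, the new gate runs three checks in order: the config flag, the GPU architecture, and the installed flash-attn version. Below is a minimal standalone sketch of that cascade; the helper name supports_paged_mode and its parameters are illustrative, not TabbyAPI's actual API. Note that the architecture check runs before flash_attn is imported, matching the ordering fix in the last commit above:

import torch

def supports_paged_mode(no_flash_attn: bool, gpu_device_list: list[int]) -> bool:
    # 1. Respect an explicit opt-out in the user's config
    if no_flash_attn:
        return False

    # 2. FA2 requires compute capability >= 8 (Ampere) on every GPU,
    #    and AMD/ROCm builds are unsupported
    min_compute_capability = min(
        torch.cuda.get_device_capability(device=idx)[0] for idx in gpu_device_list
    )
    if torch.version.hip or min_compute_capability < 8:
        return False

    # 3. flash-attn must be installed and at least version 2.5.7
    try:
        import flash_attn
    except ModuleNotFoundError:
        return False
    flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
    return flash_attn_ver >= [2, 5, 7]

In the real container the failing branches also log a warning and force max_batch_size = 1; the sketch only returns the decision.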
Author: DocShotgun
Date: 2024-06-05 00:33:21 -07:00
Committed by: GitHub
Parent: dbdcb38ad7
Commit: e391d84e40


@@ -224,26 +224,72 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Disable paged mode if the user's min GPU isn't supported (ampere and up)
-        min_compute_capability = min(
-            torch.cuda.get_device_capability(device=device_idx)[0]
-            for device_idx in gpu_device_list
-        )
-
-        # Compute capability < 8 is not supported by FA2
-        # AMD is also unsupported until ROCm updates its FA2 fork
-        if torch.version.hip or min_compute_capability < 8:
+        # Check whether the user's configuration supports paged attention
+        if self.config.no_flash_attn:
             logger.warning(
-                "An unsupported GPU is found in this configuration. "
+                "Flash attention is disabled via config. "
                 "Switching to compatibility mode. \n"
                 "This disables parallel batching "
-                "and features that rely on it (ex. CFG). \n"
-                "To disable compatability mode, all GPUs must be ampere "
-                "(30 series) or newer. AMD GPUs are not supported."
+                "and features that rely on it (ex. CFG)."
             )
-            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+        else:
+            try:
+                # Disable paged mode if the user's min GPU isn't supported (ampere+)
+                min_compute_capability = min(
+                    torch.cuda.get_device_capability(device=device_idx)[0]
+                    for device_idx in gpu_device_list
+                )
+
+                # Compute capability < 8 is not supported by FA2
+                # AMD is also unsupported until ROCm updates its FA2 fork
+                if torch.version.hip or min_compute_capability < 8:
+                    logger.warning(
+                        "An unsupported GPU is found in this configuration. "
+                        "Switching to compatibility mode. \n"
+                        "This disables parallel batching "
+                        "and features that rely on it (ex. CFG). \n"
+                        "To disable compatibility mode, all GPUs must be ampere "
+                        "(30 series) or newer. AMD GPUs are not supported."
+                    )
+                    self.config.no_flash_attn = True
+                    self.paged = False
+                    self.max_batch_size = 1
+                else:
+                    import flash_attn
+
+                    flash_attn_ver = [
+                        int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
+                    ]
+
+                    # Disable paged mode if the user's flash attention version < 2.5.7
+                    if flash_attn_ver < [2, 5, 7]:
+                        logger.warning(
+                            "Flash attention version is older than 2.5.7 "
+                            "which is required for paged attention. "
+                            "Switching to compatibility mode. \n"
+                            "This disables parallel batching "
+                            "and features that rely on it (ex. CFG). \n"
+                            "Please run start.bat or start.sh to update. \n"
+                            "NOTE: Windows users must select CUDA 12.x to use FA2."
+                        )
+                        self.paged = False
+                        self.max_batch_size = 1
+            except ModuleNotFoundError:
+                # Disable paged mode if flash attention is not installed
+                logger.warning(
+                    "Flash attention is not installed. "
+                    "Switching to compatibility mode. \n"
+                    "This disables parallel batching "
+                    "and features that rely on it (ex. CFG). \n"
+                    "Please run start.bat or start.sh to install. \n"
+                    "NOTE: Windows users must select CUDA 12.x to use FA2."
+                )
+                self.config.no_flash_attn = True
+                self.paged = False
+                self.max_batch_size = 1
 
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
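
One subtlety in the version check added above: the list comprehension keeps only the purely numeric dot-separated components, so a suffix like ".post1" is dropped, and Python then compares the resulting lists of ints lexicographically. A standalone demonstration of that parse (assuming flash-attn version strings such as "2.5.8" or "2.6.0.post1"):

def parse_fa_version(version: str) -> list[int]:
    # Keep only numeric components: "2.6.0.post1" -> [2, 6, 0]
    return [int(t) for t in version.split(".") if t.isdigit()]

assert parse_fa_version("2.5.8") >= [2, 5, 7]        # new enough
assert parse_fa_version("2.6.0.post1") >= [2, 5, 7]  # ".post1" suffix ignored
assert parse_fa_version("2.4.2") < [2, 5, 7]         # too old: compatibility mode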