mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
ExllamaV2: Cleanup log placements
Move the large import error messages into the check functions themselves. This makes it easier to tell where an error is coming from. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -202,35 +202,11 @@ class ExllamaV2Container:
|
||||
|
||||
# Check whether the user's configuration supports paged attention
|
||||
if not hardware_supports_flash_attn(gpu_device_list):
|
||||
logger.warning(
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
elif not supports_paged_attn():
|
||||
logger.warning(
|
||||
"Flash attention version >=2.5.7 "
|
||||
"is required to use paged attention. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"Please upgrade your environment by running a start script "
|
||||
"(start.bat or start.sh)\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade .[cu118]\n\n"
|
||||
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
|
||||
)
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
||||
@@ -10,21 +10,23 @@ def check_exllama_version():
|
||||
required_version = version.parse("0.1.5")
|
||||
current_version = version.parse(package_version("exllamav2").split("+")[0])
|
||||
|
||||
unsupported_message = (
|
||||
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
|
||||
f"or greater. Your current version is {current_version}.\n"
|
||||
"Please upgrade your environment by running a start script "
|
||||
"(start.bat or start.sh)\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade .[cu118]\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade .[amd]\n\n"
|
||||
)
|
||||
|
||||
if current_version < required_version:
|
||||
raise SystemExit(
|
||||
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
|
||||
f"or greater. Your current version is {current_version}.\n"
|
||||
"Please upgrade your environment by running a start script "
|
||||
"(start.bat or start.sh)\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade .[cu118]\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade .[amd]\n\n"
|
||||
)
|
||||
raise SystemExit(unsupported_message)
|
||||
else:
|
||||
logger.info(f"ExllamaV2 version: {current_version}")
|
||||
|
||||
@@ -37,11 +39,23 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
AMD is also unsupported until ROCm updates its FA2 fork
|
||||
"""
|
||||
|
||||
# Logged message if unsupported
|
||||
unsupported_message = (
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -50,13 +64,33 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
def supports_paged_attn():
    """
    Check whether the installed flash-attn version supports paged mode.

    Returns:
        True when flash-attn is installed at version 2.5.7 or newer
        (anything after a "+" in the version string is ignored).
        False otherwise, after logging a warning that explains the
        fallback to compatibility mode and how to upgrade.
    """

    # Warning text logged whenever paged attention cannot be used
    unsupported_message = (
        "Flash attention version >=2.5.7 "
        "is required to use paged attention. "
        "Switching to compatibility mode. \n"
        "This disables parallel batching "
        "and features that rely on it (ex. CFG). \n"
        "Please upgrade your environment by running a start script "
        "(start.bat or start.sh)\n\n"
        "Or you can manually run a requirements update "
        "using the following command:\n\n"
        "For CUDA 12.1:\n"
        "pip install --upgrade .[cu121]\n\n"
        "For CUDA 11.8:\n"
        "pip install --upgrade .[cu118]\n\n"
        "NOTE: Windows users must use CUDA 12.x to use flash-attn."
    )

    required_version = version.parse("2.5.7")

    # A missing flash-attn install is handled the same way as an outdated one
    try:
        installed = package_version("flash-attn")
    except PackageNotFoundError:
        current_version = None
    else:
        # Drop any local build suffix (the part after "+") before parsing
        current_version = version.parse(installed.split("+")[0])

    if current_version is None or current_version < required_version:
        logger.warning(unsupported_message)
        return False

    return True
|
||||
|
||||
Reference in New Issue
Block a user