Files
tabbyAPI/backends/exllamav2/utils.py
kingbri c575105e41 ExllamaV2: Cleanup log placements
Move the large import errors into the check functions themselves.
This helps reduce the difficulty in interpreting where errors are
coming from.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-06-16 00:16:03 -04:00

97 lines
3.2 KiB
Python

from packaging import version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
import torch
def check_exllama_version():
"""Verifies the exllama version"""
required_version = version.parse("0.1.5")
current_version = version.parse(package_version("exllamav2").split("+")[0])
unsupported_message = (
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
f"or greater. Your current version is {current_version}.\n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"For CUDA 11.8:\n"
"pip install --upgrade .[cu118]\n\n"
"For ROCm:\n"
"pip install --upgrade .[amd]\n\n"
)
if current_version < required_version:
raise SystemExit(unsupported_message)
else:
logger.info(f"ExllamaV2 version: {current_version}")
def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Compute capability < 8 is not supported by FA2
AMD is also unsupported until ROCm updates its FA2 fork
"""
# Logged message if unsupported
unsupported_message = (
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)
if torch.version.hip or min_compute_capability < 8:
logger.warning(unsupported_message)
return False
else:
return True
def supports_paged_attn():
"""Check whether the user's flash-attn version supports paged mode"""
# Logged message if unsupported
unsupported_message = (
"Flash attention version >=2.5.7 "
"is required to use paged attention. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"For CUDA 11.8:\n"
"pip install --upgrade .[cu118]\n\n"
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
)
required_version = version.parse("2.5.7")
try:
current_version = version.parse(package_version("flash-attn").split("+")[0])
except PackageNotFoundError:
logger.warning(unsupported_message)
return False
if current_version < required_version:
logger.warning(unsupported_message)
return False
else:
return True