mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Model: Change FA2 and paged attention checks
The dynamic generator requires Flash attention 2.5.7 or higher to be installed. This is only supported on Nvidia's 30 series and higher. If a card is AMD or lower than the 30 series, switch to compatability mode which functions the same way as the older generator, except without parallel batching and any features that depend on it, such as CFG. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -3,8 +3,6 @@
|
||||
import gc
|
||||
import math
|
||||
import pathlib
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import torch
|
||||
import uuid
|
||||
@@ -57,10 +55,11 @@ class ExllamaV2Container:
|
||||
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
active_loras: List[ExLlamaV2Lora] = []
|
||||
paged: bool = True
|
||||
|
||||
# Internal config vars
|
||||
cache_mode: str = "FP16"
|
||||
use_cfg: bool = False
|
||||
max_batch_size: int = 20
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
@@ -115,10 +114,6 @@ class ExllamaV2Container:
|
||||
available devices (default: True)
|
||||
'gpu_split' (list[float]): Allocation for weights and (some)
|
||||
tensors, per device
|
||||
'no_flash_attn' (bool): Turns off flash attention
|
||||
(increases vram usage) (default: False)
|
||||
'use_cfg" (bool): Enables CFG support. Disables flash attention
|
||||
(default: False)
|
||||
"""
|
||||
|
||||
self.quiet = quiet
|
||||
@@ -184,18 +179,9 @@ class ExllamaV2Container:
|
||||
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
|
||||
)
|
||||
|
||||
# Enable CFG if present
|
||||
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
|
||||
|
||||
# Enable fasttensors loading if present
|
||||
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
|
||||
|
||||
# Turn off flash attention if CFG is on
|
||||
# Workaround until batched FA2 is fixed in exllamav2 upstream
|
||||
# self.config.no_flash_attn = (
|
||||
# True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
|
||||
# )
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
@@ -345,7 +331,6 @@ class ExllamaV2Container:
|
||||
"cache_mode": self.cache_mode,
|
||||
"chunk_size": self.config.max_input_len,
|
||||
"num_experts_per_token": self.config.num_experts_per_token,
|
||||
"use_cfg": self.use_cfg,
|
||||
"prompt_template": self.prompt_template.name
|
||||
if self.prompt_template
|
||||
else None,
|
||||
@@ -420,10 +405,24 @@ class ExllamaV2Container:
|
||||
async for value in iterate_in_threadpool(model_load_generator):
|
||||
yield value
|
||||
|
||||
# TODO: Change these!
|
||||
# Set the max batch size and check if paged support is available
|
||||
max_batch_size = 1 if self.config.no_flash_attn else 20
|
||||
paged = not self.config.no_flash_attn
|
||||
# Disable paged mode if the user's min GPU is supported (ampere and above)
|
||||
min_compute_capability = min(
|
||||
set(
|
||||
[
|
||||
torch.cuda.get_device_capability(device=module.device_idx)[0]
|
||||
for module in self.model.modules
|
||||
if module.device_idx >= 0
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. This disables parallel batching."
|
||||
)
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
||||
# Create async generator
|
||||
self.generator = ExLlamaV2DynamicGeneratorAsync(
|
||||
@@ -432,8 +431,8 @@ class ExllamaV2Container:
|
||||
draft_model=self.draft_model,
|
||||
draft_cache=self.draft_cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
paged=paged,
|
||||
max_batch_size=self.max_batch_size,
|
||||
paged=self.paged,
|
||||
)
|
||||
|
||||
# Clean up any extra vram usage from torch and cuda
|
||||
@@ -741,7 +740,7 @@ class ExllamaV2Container:
|
||||
cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
|
||||
negative_prompt = None
|
||||
if cfg_scale not in [None, 1.0]:
|
||||
if self.use_cfg:
|
||||
if self.paged:
|
||||
gen_settings.cfg_scale = cfg_scale
|
||||
|
||||
# If the negative prompt is empty, use the BOS token
|
||||
@@ -752,8 +751,8 @@ class ExllamaV2Container:
|
||||
prompts.append(negative_prompt)
|
||||
else:
|
||||
logger.warning(
|
||||
"CFG is currently disabled. "
|
||||
"If your GPU is supported, reload your model with use_cfg = True"
|
||||
"CFG is currently disabled because paged mode is disabled. "
|
||||
"Please use an ampere (30 series) or higher GPU for CFG support."
|
||||
)
|
||||
|
||||
gen_settings.token_repetition_penalty = unwrap(
|
||||
|
||||
Reference in New Issue
Block a user