mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 10:11:39 +00:00
Exl3: Add chunk size, cache size, and model info
Use the same algorithm for estimating and adjusting the cache size: round up to a multiple of 256 and keep it at or above max_seq_len. The same applies to chunk size. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -27,13 +27,14 @@ from common.gen_logging import (
|
||||
log_generation_params,
|
||||
log_metrics,
|
||||
)
|
||||
from common.hardware import hardware_supports_flash_attn
|
||||
from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
||||
class ExllamaV3Container(BaseModelContainer):
|
||||
@@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
tokenizer: Tokenizer
|
||||
config: Config
|
||||
generator: Optional[AsyncGenerator] = None
|
||||
|
||||
# Class-specific vars
|
||||
gpu_split: List[float] | None = None
|
||||
gpu_split_auto: bool = True
|
||||
autosplit_reserve: List[float] = [96 / 1024]
|
||||
max_seq_len: int
|
||||
use_tp: bool = False
|
||||
max_seq_len: int = 4096
|
||||
cache_size: int = 4096
|
||||
chunk_size: int = 2048
|
||||
max_batch_size: Optional[int] = None
|
||||
|
||||
# Required methods
|
||||
@classmethod
|
||||
@@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
self.model = Model.from_config(self.config)
|
||||
self.tokenizer = Tokenizer.from_config(self.config)
|
||||
|
||||
self.max_seq_len = kwargs.get("max_seq_len")
|
||||
self.cache = Cache(self.model, max_num_tokens=self.max_seq_len)
|
||||
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
||||
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = await find_prompt_template(
|
||||
@@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
gpu_split = unwrap(kwargs.get("gpu_split"), None)
|
||||
gpu_device_list = list(range(0, gpu_count))
|
||||
|
||||
# Set GPU split options
|
||||
if gpu_count == 1:
|
||||
@@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
# Enable manual GPU split if provided
|
||||
if gpu_split:
|
||||
self.gpu_split = gpu_split
|
||||
|
||||
gpu_device_list = [
|
||||
device_idx
|
||||
for device_idx, memory in enumerate(self.gpu_split)
|
||||
if memory > 0
|
||||
]
|
||||
elif gpu_split_auto and not self.use_tp:
|
||||
# Otherwise fallback to autosplit settings
|
||||
self.gpu_split_auto = gpu_split_auto
|
||||
@@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
self.autosplit_reserve = [
|
||||
value / 1024 for value in autosplit_reserve_megabytes
|
||||
]
|
||||
|
||||
if not hardware_supports_flash_attn(gpu_device_list):
|
||||
gpu_unsupported_message = (
|
||||
"Unable to run ExllamaV3 because an unsupported GPU is "
|
||||
"found in this configuration. \n"
|
||||
"All GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
logger.warning(gpu_unsupported_message)
|
||||
|
||||
raise RuntimeError(gpu_unsupported_message)
|
||||
|
||||
# Cache
|
||||
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
|
||||
self.cache_size = self.adjust_cache_size(user_cache_size)
|
||||
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
|
||||
|
||||
# Max batch size
|
||||
self.max_batch_size = kwargs.get("max_batch_size")
|
||||
|
||||
# Make sure chunk size is >= 256, keep near or below max seq len
|
||||
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
|
||||
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
|
||||
|
||||
# TODO: speculative decoding
|
||||
|
||||
return self
|
||||
|
||||
def adjust_cache_size(self, cache_size):
    """Clamp and align a requested KV cache size.

    Guarantees the returned size is at least ``self.max_seq_len`` and a
    multiple of 256, always rounding upward so the cache never ends up
    smaller than the requested context length. Emits a warning for every
    adjustment it makes.

    Args:
        cache_size: Requested cache size in tokens.

    Returns:
        The adjusted cache size in tokens.
    """

    # Never allow a cache smaller than the context window
    if cache_size < self.max_seq_len:
        logger.warning(
            f"The given cache_size ({cache_size}) is smaller than the "
            "desired context length.\n"
            "Overriding cache_size to max_seq_len. "
        )

        cache_size = self.max_seq_len

    # Enforce a multiple of 256 for cache size
    # Overestimate (round up) to ensure that the cache isn't below max_seq_len
    cache_remainder = cache_size % 256
    if cache_remainder != 0:
        rounded_cache_size = 256 * (cache_size // 256 + 1)

        logger.warning(
            f"The given cache size ({cache_size}) is "
            "not a multiple of 256.\n"
            "Overriding cache_size with an overestimated value of "
            f"{rounded_cache_size} tokens."
        )

        cache_size = rounded_cache_size

    # Warn user if cache size may be inadequate for CFG, which needs
    # roughly twice the context per request
    if cache_size < 2 * self.max_seq_len:
        logger.warning(
            f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
            "and may be too small for requests using CFG. \n"
            "Ignore this warning if you do not plan on using CFG."
        )

    return cache_size
def adjust_chunk_size(self, user_chunk_size: int):
    """Clamp and align a requested prefill chunk size.

    Clamps the value into ``[256, self.max_seq_len]`` and then rounds up
    to the next multiple of 256 when not already aligned, warning on the
    rounding adjustment.

    Args:
        user_chunk_size: Requested chunk size in tokens.

    Returns:
        The adjusted chunk size in tokens.
    """

    # Median of the three values clamps user input into [256, max_seq_len]
    chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]

    # Enforce a multiple of 256, rounding upward when unaligned
    chunk_remainder = chunk_size % 256
    if chunk_remainder != 0:
        rounded_chunk_size = 256 * (chunk_size // 256 + 1)

        logger.warning(
            f"The given chunk size ({chunk_size}) is "
            "not a multiple of 256.\n"
            "Overriding chunk_size with an overestimated value of "
            f"{rounded_chunk_size} tokens."
        )

        chunk_size = rounded_chunk_size

    return chunk_size
def model_info(self) -> ModelCard:
    """
    Returns a dictionary of the current model's configuration parameters.

    Returns:
        ModelCard: Model parameters provided by the backend
    """

    model_params = ModelCardParameters(
        max_seq_len=self.max_seq_len,
        cache_size=self.cache_size,
        max_batch_size=self.max_batch_size,
        # cache_mode=self.cache_mode,
        chunk_size=self.chunk_size,
        use_vision=self.use_vision,
    )

    # Surface the active prompt template when one was resolved at load time
    if self.prompt_template:
        model_params.prompt_template = self.prompt_template.name
        model_params.prompt_template_content = self.prompt_template.raw_template

    model_card = ModelCard(
        id=self.model_dir.name,
        parameters=model_params,
    )

    return model_card
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""
|
||||
@@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
cache=self.cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_chunk_size=self.chunk_size,
|
||||
)
|
||||
|
||||
# Update the state of the container var
|
||||
|
||||
Reference in New Issue
Block a user