Exl3: Add chunk size, cache size, and model info

Use the same algorithm for estimating and adjusting the cache size:
round it up to a multiple of 256 and keep it at or above max seq len.

Chunk size gets the same multiple-of-256 rounding, clamped between 256
and max seq len.
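
As a minimal sketch (not part of the diff), the adjustment boils down to flooring
the cache at max_seq_len, clamping the chunk size into [256, max_seq_len], and
rounding both up to the next multiple of 256:

```python
# Illustrative only; the real logic lives in adjust_cache_size / adjust_chunk_size below.
def round_up_to_256(value: int) -> int:
    return ((value + 255) // 256) * 256  # 4097 -> 4352, 4096 stays 4096

max_seq_len = 4096
cache_size = round_up_to_256(max(3000, max_seq_len))            # floored at max_seq_len -> 4096
chunk_size = round_up_to_256(min(max(2047, 256), max_seq_len))  # clamped, then rounded -> 2048
```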

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri
2025-04-30 23:58:27 -04:00
parent 303e2dde12
commit bdc5189a4b
3 changed files with 130 additions and 83 deletions

@@ -27,13 +27,14 @@ from common.gen_logging import (
log_generation_params,
log_metrics,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard
from endpoints.core.types.model import ModelCard, ModelCardParameters
class ExllamaV3Container(BaseModelContainer):
@@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer):
tokenizer: Tokenizer
config: Config
generator: Optional[AsyncGenerator] = None
# Class-specific vars
gpu_split: List[float] | None = None
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 / 1024]
max_seq_len: int
use_tp: bool = False
max_seq_len: int = 4096
cache_size: int = 4096
chunk_size: int = 2048
max_batch_size: Optional[int] = None
# Required methods
@classmethod
@@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer):
self.model = Model.from_config(self.config)
self.tokenizer = Tokenizer.from_config(self.config)
self.max_seq_len = kwargs.get("max_seq_len")
self.cache = Cache(self.model, max_num_tokens=self.max_seq_len)
# Fall back to 4096 since exl3 can't fetch max_seq_len from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Try to set prompt template
self.prompt_template = await find_prompt_template(
@@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer):
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
gpu_split = unwrap(kwargs.get("gpu_split"), None)
gpu_device_list = list(range(0, gpu_count))
# Set GPU split options
if gpu_count == 1:
@@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer):
# Enable manual GPU split if provided
if gpu_split:
self.gpu_split = gpu_split
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
elif gpu_split_auto and not self.use_tp:
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto
@@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer):
self.autosplit_reserve = [
value / 1024 for value in autosplit_reserve_megabytes
]
if not hardware_supports_flash_attn(gpu_device_list):
gpu_unsupported_message = (
"Unable to run ExllamaV3 because an unsupported GPU is "
"found in this configuration. \n"
"All GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
logger.warning(gpu_unsupported_message)
raise RuntimeError(gpu_unsupported_message)
# Cache
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
# Max batch size
self.max_batch_size = kwargs.get("max_batch_size")
# Make sure chunk size is >= 256, keep near or below max seq len
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
# TODO: speculative decoding
return self
def adjust_cache_size(self, cache_size):
if cache_size < self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
cache_size = self.max_seq_len
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)
cache_size = rounded_cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)
return cache_size
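
For reference, this is how the method above should behave for max_seq_len = 4096.
The snippet is a standalone restatement with example values, not the container
method itself:

```python
def adjusted_cache_size(cache_size: int, max_seq_len: int = 4096) -> int:
    cache_size = max(cache_size, max_seq_len)  # never smaller than the context length
    return ((cache_size + 255) // 256) * 256   # overestimate to the next multiple of 256

assert adjusted_cache_size(3000) == 4096   # raised to max_seq_len
assert adjusted_cache_size(5000) == 5120   # rounded up to a multiple of 256
assert adjusted_cache_size(8192) == 8192   # 2 * max_seq_len, enough for CFG requests
```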
def adjust_chunk_size(self, user_chunk_size: int):
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
chunk_remainder = chunk_size % 256
if chunk_remainder != 0:
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
logger.warning(
f"The given chunk size ({chunk_size}) is "
"not a multiple of 256.\n"
"Overriding chunk_size with an overestimated value of "
f"{rounded_chunk_size} tokens."
)
chunk_size = rounded_chunk_size
return chunk_size
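
The `sorted((256, user_chunk_size, self.max_seq_len))[1]` expression is a
median-of-three clamp: taking the middle element pins the value into
[256, max_seq_len] (assuming max_seq_len >= 256). A quick illustration with
example values:

```python
def clamp_chunk(user_chunk_size: int, max_seq_len: int = 4096) -> int:
    # sorted((lo, value, hi))[1] is the middle element, i.e. value clamped to [lo, hi]
    return sorted((256, user_chunk_size, max_seq_len))[1]

assert clamp_chunk(128) == 256     # below the floor: raised to 256
assert clamp_chunk(2048) == 2048   # in range: unchanged
assert clamp_chunk(8192) == 4096   # above max_seq_len: capped
```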
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.
@@ -138,7 +228,25 @@ class ExllamaV3Container(BaseModelContainer):
Model parameters provided by the backend
"""
pass
model_params = ModelCardParameters(
max_seq_len=self.max_seq_len,
cache_size=self.cache_size,
max_batch_size=self.max_batch_size,
# cache_mode=self.cache_mode,
chunk_size=self.chunk_size,
use_vision=self.use_vision,
)
if self.prompt_template:
model_params.prompt_template = self.prompt_template.name
model_params.prompt_template_content = self.prompt_template.raw_template
model_card = ModelCard(
id=self.model_dir.name,
parameters=model_params,
)
return model_card
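
Put together, a card built here carries roughly the following data. Field names
are taken from the diff; the id, template name, and values are hypothetical
examples, and the real schemas live in `endpoints.core.types.model`:

```python
# Hypothetical example payload for model_info(); values are illustrative.
example_card = {
    "id": "my-exl3-model",              # self.model_dir.name
    "parameters": {
        "max_seq_len": 4096,
        "cache_size": 4096,
        "max_batch_size": None,
        "chunk_size": 2048,
        "use_vision": False,
        "prompt_template": "chatml",    # only set when a template was found
        "prompt_template_content": "...",
    },
}
```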
async def wait_for_jobs(self, skip_wait: bool = False):
"""
@@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer):
cache=self.cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
max_chunk_size=self.chunk_size,
)
# Update the state of the container var
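
Assuming `max_chunk_size` plays the usual chunked-prefill role here (capping how
many prompt tokens the generator processes per forward pass), prompt ingestion
cost scales with the number of chunks. A rough, standalone illustration:

```python
import math

# Illustrative only: prefill passes needed for a prompt at chunk_size = 2048.
chunk_size = 2048
for prompt_tokens in (1500, 4096, 10000):
    passes = math.ceil(prompt_tokens / chunk_size)
    print(f"{prompt_tokens} prompt tokens -> {passes} prefill pass(es)")
```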