Model: Change cache_size/max_seq_len behavior

- Cache size is now set only by the cache_size config option. The default is 4096 tokens; users should override this to make full use of available VRAM.
- max_seq_len, if not overridden in the config, now defaults to the value in the model's config.json (max_position_embeddings).
- max_seq_len is reduced so that it is never larger than the cache size (see the sketch below).
Author: turboderp
Date: 2025-10-05 22:15:27 +02:00
Parent: d672dc2137
Commit: 4235f98e83
5 changed files with 37 additions and 63 deletions
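In practice, the new cache_size/max_seq_len resolution works roughly as follows. This is a minimal sketch of the behavior described in the commit message, assuming dict-style load kwargs; the function and argument names are illustrative, not the actual container code.

```python
def resolve_lengths(kwargs: dict, config_json_max_position_embeddings: int):
    # cache_size comes only from the config; 4096 tokens if unset
    cache_size = kwargs.get("cache_size") or 4096
    # rounded up to a multiple of 256
    cache_size = (cache_size + 255) // 256 * 256

    # max_seq_len falls back to the model's config.json value
    max_seq_len = kwargs.get("max_seq_len") or config_json_max_position_embeddings

    # max_seq_len is never allowed to exceed the cache
    max_seq_len = min(max_seq_len, cache_size)
    return cache_size, max_seq_len
```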


@@ -238,7 +238,7 @@ class ExllamaV2Container(BaseModelContainer):
base_seq_len = hf_model.hf_config.max_position_embeddings
# Set the target seq len if present
- target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
+ target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
# Set the rope scale
self.config.scale_pos_emb = unwrap(
@@ -289,16 +289,7 @@ class ExllamaV2Container(BaseModelContainer):
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
- cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
- if cache_size < self.config.max_seq_len:
- logger.warning(
- f"The given cache_size ({cache_size}) is smaller than the "
- "desired context length.\n"
- "Overriding cache_size to max_seq_len. "
- )
- cache_size = self.config.max_seq_len
+ cache_size = unwrap(kwargs.get("cache_size"), 4096)
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
@@ -317,6 +308,13 @@ class ExllamaV2Container(BaseModelContainer):
cache_size = rounded_cache_size
+ if self.config.max_seq_len > cache_size:
+ logger.warning(
+ f"The given max_seq_len ({self.config.max_seq_len}) is larger than the "
+ f"cache size and will be limited to {cache_size} tokens."
+ )
+ self.config.max_seq_len = cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.config.max_seq_len:
logger.warning(
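To make the exllamav2 flow above concrete, here is a worked example with hypothetical numbers (not taken from the diff): the cache is rounded up to a multiple of 256 first, and max_seq_len is then limited to it.

```python
cache_size = 10000                       # user-supplied cache_size
cache_size = (cache_size + 255) // 256 * 256
assert cache_size == 10240               # rounded up to the next multiple of 256

max_seq_len = 16384                      # e.g. from the model's config.json
if max_seq_len > cache_size:
    max_seq_len = cache_size             # warned about and limited to 10240

# cache_size < 2 * max_seq_len here, so the CFG warning would also fire
assert cache_size < 2 * max_seq_len
```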


@@ -105,12 +105,7 @@ class ExllamaV3Container(BaseModelContainer):
self = cls()
# Make sure ExllamaV3 is up to date
check_package_version("exllamav3", "0.0.4")
logger.warning(
"ExllamaV3 is currently in an alpha state. "
"Please note that all config options may not work."
)
check_package_version("exllamav3", "0.0.7")
self.model_dir = model_directory
self.hf_model = hf_model
@@ -131,9 +126,6 @@ class ExllamaV3Container(BaseModelContainer):
self.vision_model = None
self.use_vision = False
- # Fallback to 4096 since exl3 can't fetch from HF's config.json
- self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -231,11 +223,15 @@ class ExllamaV3Container(BaseModelContainer):
raise RuntimeError(gpu_unsupported_message)
# Cache
- user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
+ user_cache_size = unwrap(kwargs.get("cache_size"), 4096)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
self.cache = self.create_cache(self.cache_mode, self.model)
+ # Limit max_seq_len to prevent sequences larger than the cache
+ max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings)
+ self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
# Draft cache
if self.use_draft_model:
# Set draft cache mode
@@ -274,21 +270,11 @@ class ExllamaV3Container(BaseModelContainer):
return self
def adjust_cache_size(self, cache_size):
- if cache_size < self.max_seq_len:
- logger.warning(
- f"The given cache_size ({cache_size}) is smaller than the "
- "desired context length.\n"
- "Overriding cache_size to max_seq_len. "
- )
- cache_size = self.max_seq_len
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
@@ -298,22 +284,22 @@ class ExllamaV3Container(BaseModelContainer):
cache_size = rounded_cache_size
- # Warn user if cache size may be inadequate for CFG
- if cache_size < 2 * self.max_seq_len:
- logger.warning(
- f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
- "and may be too small for requests using CFG. \n"
- "Ignore this warning if you do not plan on using CFG."
- )
return cache_size
- def adjust_chunk_size(self, user_chunk_size: int):
- chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
- chunk_remainder = chunk_size % 256
- if chunk_remainder != 0:
- rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
+ def adjust_max_seq_len(self, max_seq_len):
+ if max_seq_len > self.cache_size:
+ logger.warning(
+ f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
+ f"and will be limited to {self.cache_size} tokens."
+ )
+ max_seq_len = self.cache_size
+ return max_seq_len
+ def adjust_chunk_size(self, user_chunk_size: int):
+ chunk_size = max(256, user_chunk_size)
+ rounded_chunk_size = (chunk_size + 255) // 256 * 256
+ if chunk_size != rounded_chunk_size:
logger.warning(
f"The given chunk size ({chunk_size}) is "
"not a multiple of 256.\n"
@@ -950,24 +936,18 @@ class ExllamaV3Container(BaseModelContainer):
context_len = input_ids[0].size(dim=-1)
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
- params.max_tokens,
- self.max_seq_len - context_len,
+ params.max_tokens if params.max_tokens > 0 else None,
+ self.max_seq_len - context_len - 1,
)
- if max_tokens < 1:
- logger.warning("max_tokens must be a positive integer, setting to 1.")
- max_tokens = 1
- # Determine if the negative context or the context length is bigger
- context_to_check = context_len
# Check total length of prompt against max context length
- if context_to_check > self.max_seq_len:
- preamble = "Prompt"
+ if context_len > self.max_seq_len:
raise ValueError(
f"{preamble} length {context_to_check} is greater than "
f"Prompt length {context_len} is greater than "
f"max_seq_len {self.max_seq_len}"
)
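With the change above, an unset or non-positive max_tokens now falls back to the remaining context minus one token of headroom, and the length check compares the prompt itself against max_seq_len. A rough illustration with made-up numbers:

```python
max_seq_len = 8192
context_len = 5000
requested_max_tokens = 0                 # 0 / unset no longer overrides the default

max_tokens = requested_max_tokens if requested_max_tokens > 0 else None
if max_tokens is None:
    max_tokens = max_seq_len - context_len - 1   # 3191 tokens left for generation

if context_len > max_seq_len:
    raise ValueError(
        f"Prompt length {context_len} is greater than max_seq_len {max_seq_len}"
    )
```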


@@ -157,10 +157,8 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Override the max sequence length based on user
max_seq_len = kwargs.get("max_seq_len")
- if max_seq_len == -1:
+ if max_seq_len == -1 or max_seq_len is None:
kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings
- elif max_seq_len is None:
- kwargs["max_seq_len"] = 4096
# Create a new container and check if the right dependencies are installed
backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
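So an explicit -1 and an omitted max_seq_len now both resolve to the model's own limit instead of 4096. A small sketch of the fallback (effective_max_seq_len is illustrative, not the loader's actual helper):

```python
def effective_max_seq_len(user_value, config_json_value):
    # -1 or unset -> take the value from the model's config.json
    if user_value == -1 or user_value is None:
        return config_json_value
    return user_value

assert effective_max_seq_len(None, 32768) == 32768    # previously fell back to 4096
assert effective_max_seq_len(-1, 32768) == 32768
assert effective_max_seq_len(8192, 32768) == 8192
```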


@@ -78,8 +78,7 @@ model:
# Options: exllamav2, exllamav3
backend:
- # Max sequence length (default: 4096).
- # Set to -1 to fetch from the model's config.json
+ # Max sequence length (default: fetch from the model's config.json).
max_seq_len:
# Load model with tensor parallelism.
@@ -124,9 +123,8 @@ model:
# For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8).
cache_mode: FP16
- # Size of the prompt cache to allocate (default: max_seq_len).
- # Must be a multiple of 256 and can't be less than max_seq_len.
- # For CFG, set this to 2 * max_seq_len.
+ # Size of the key/value cache to allocate, in tokens (default: 4096).
+ # Must be a multiple of 256.
cache_size:
# Chunk size for prompt ingestion (default: 2048).
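Under the new scheme the two options are independent: max_seq_len can be left unset to follow the model, while cache_size is sized to the available VRAM. As a rough illustration, the load kwargs a backend receives might look like this (values are made up):

```python
kwargs = {
    "max_seq_len": None,     # unset -> taken from the model's config.json
    "cache_size": 32768,     # sized to fill VRAM; must be a multiple of 256
    "cache_mode": "FP16",
}
```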


@@ -85,7 +85,7 @@ class ModelLoadRequest(BaseModel):
examples=[4096],
)
cache_size: Optional[int] = Field(
description=("Number in tokens, must be greater than or equal to max_seq_len"),
description="Number in tokens, must be multiple of 256",
default=None,
examples=[4096],
)
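On the API side, a load request can now pair a larger cache with a smaller max_seq_len, for example to leave room for several concurrent sequences; the only constraint the field documents is the multiple-of-256 requirement. An illustrative request body (values are made up):

```python
payload = {
    "max_seq_len": 16384,    # optional; falls back to the model's config.json if omitted
    "cache_size": 65536,     # tokens; must be a multiple of 256
}
```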