Model: Change cache_size/max_seq_len behavior

- Cache size is now given only by the cache_size config option. The default is 4096 (users should always override this to make full use of available VRAM)
- max_seq_len, if not overridden in the config, will default to the value in the model's config.json
- max_seq_len is reduced to be no larger than the cache
This commit is contained in:
turboderp
2025-10-05 22:15:27 +02:00
parent d672dc2137
commit 4235f98e83
5 changed files with 37 additions and 63 deletions

View File

@@ -238,7 +238,7 @@ class ExllamaV2Container(BaseModelContainer):
base_seq_len = hf_model.hf_config.max_position_embeddings
# Set the target seq len if present
target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
# Set the rope scale
self.config.scale_pos_emb = unwrap(
@@ -289,16 +289,7 @@ class ExllamaV2Container(BaseModelContainer):
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
if cache_size < self.config.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
cache_size = self.config.max_seq_len
cache_size = unwrap(kwargs.get("cache_size"), 4096)
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
@@ -317,6 +308,13 @@ class ExllamaV2Container(BaseModelContainer):
cache_size = rounded_cache_size
if self.config.max_seq_len > cache_size:
logger.warning(
f"The given max_seq_len ({self.config.max_seq_len}) is larger than the "
f"cache size and will be limited to {cache_size} tokens."
)
self.config.max_seq_len = cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.config.max_seq_len:
logger.warning(

View File

@@ -105,12 +105,7 @@ class ExllamaV3Container(BaseModelContainer):
self = cls()
# Make sure ExllamaV3 is up to date
check_package_version("exllamav3", "0.0.4")
logger.warning(
"ExllamaV3 is currently in an alpha state. "
"Please note that all config options may not work."
)
check_package_version("exllamav3", "0.0.7")
self.model_dir = model_directory
self.hf_model = hf_model
@@ -131,9 +126,6 @@ class ExllamaV3Container(BaseModelContainer):
self.vision_model = None
self.use_vision = False
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -231,11 +223,15 @@ class ExllamaV3Container(BaseModelContainer):
raise RuntimeError(gpu_unsupported_message)
# Cache
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
user_cache_size = unwrap(kwargs.get("cache_size"), 4096)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
self.cache = self.create_cache(self.cache_mode, self.model)
# Limit max_seq_len to prevent sequences larger than the cache
max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings)
self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
# Draft cache
if self.use_draft_model:
# Set draft cache mode
@@ -274,21 +270,11 @@ class ExllamaV3Container(BaseModelContainer):
return self
def adjust_cache_size(self, cache_size):
if cache_size < self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
cache_size = self.max_seq_len
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
@@ -298,22 +284,22 @@ class ExllamaV3Container(BaseModelContainer):
cache_size = rounded_cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)
return cache_size
def adjust_chunk_size(self, user_chunk_size: int):
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
chunk_remainder = chunk_size % 256
if chunk_remainder != 0:
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
def adjust_max_seq_len(self, max_seq_len):
    """Clamp the requested max_seq_len so it never exceeds the cache size.

    Sequences longer than the allocated k/v cache cannot be served, so a
    too-large request is reduced to ``self.cache_size`` with a warning.
    Returns the (possibly clamped) max_seq_len.
    """
    # Within bounds — nothing to adjust.
    if max_seq_len <= self.cache_size:
        return max_seq_len

    # Warn with the originally requested value before clamping.
    logger.warning(
        f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
        f"and will be limited to {self.cache_size} tokens."
    )
    return self.cache_size
def adjust_chunk_size(self, user_chunk_size: int):
chunk_size = max(256, user_chunk_size)
rounded_chunk_size = (chunk_size + 255) // 256 * 256
if chunk_size != rounded_chunk_size:
logger.warning(
f"The given chunk size ({chunk_size}) is "
"not a multiple of 256.\n"
@@ -950,24 +936,18 @@ class ExllamaV3Container(BaseModelContainer):
context_len = input_ids[0].size(dim=-1)
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
params.max_tokens,
self.max_seq_len - context_len,
params.max_tokens if params.max_tokens > 0 else None,
self.max_seq_len - context_len - 1,
)
if max_tokens < 1:
logger.warning("max_tokens must be a positive integer, setting to 1.")
max_tokens = 1
# Determine if the negative context or the context length is bigger
context_to_check = context_len
# Check total length of prompt against max context length
if context_to_check > self.max_seq_len:
preamble = "Prompt"
if context_len > self.max_seq_len:
raise ValueError(
f"{preamble} length {context_to_check} is greater than "
f"Prompt length {context_len} is greater than "
f"max_seq_len {self.max_seq_len}"
)