ExllamaV2: Add max_seq_len empty case like ExllamaV3

Also remove the intermediate base_seq_len and target_seq_len variables
to make the code clearer.

If paged mode is off, max_seq_len becomes the prime mover since batching
is unavailable.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri
2025-10-14 23:02:52 -04:00
parent 69a25d7fa6
commit fdb86f4c63

View File

@@ -233,33 +233,6 @@ class ExllamaV2Container(BaseModelContainer):
# Hardcode max output length to 16
self.config.max_output_len = 16
# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = hf_model.hf_config.max_position_embeddings
# Set the target seq len if present
target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
# Set the rope scale
self.config.scale_pos_emb = unwrap(
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
# Sets rope alpha value.
# Utilize the model's max_position_embeddings as a base value
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = calculate_rope_alpha(
base_seq_len, target_seq_len
)
else:
self.config.scale_alpha_value = rope_alpha
# Set the max seq len if specified
if target_seq_len:
self.config.max_seq_len = target_seq_len
# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
@@ -286,46 +259,36 @@ class ExllamaV2Container(BaseModelContainer):
self.max_batch_size = 1
torch.backends.cuda.enable_flash_sdp(False)
# Grab user-set max seq len
user_max_seq_len = kwargs.get("max_seq_len")
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
cache_size = unwrap(kwargs.get("cache_size"), 4096)
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(
256 * ((cache_size - cache_remainder) / 256 + 1)
)
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)
cache_size = rounded_cache_size
if self.config.max_seq_len > cache_size:
logger.warning(
f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
f"the cache size and will be limited to {cache_size} tokens."
)
self.config.max_seq_len = cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.config.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)
self.cache_size = cache_size
user_cache_size = coalesce(kwargs.get("cache_size"), user_max_seq_len, 4096)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.config.max_seq_len = self.adjust_max_seq_len(user_max_seq_len)
else:
self.cache_size = self.config.max_seq_len
self.config.max_seq_len = unwrap(
user_max_seq_len,
min(hf_model.hf_config.max_position_embeddings, 4096)
)
# Set the rope scale
self.config.scale_pos_emb = unwrap(
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
# Sets rope alpha value.
# Utilize the model's max_position_embeddings as a base value
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = calculate_rope_alpha(
hf_model.hf_config.max_position_embeddings, self.config.max_seq_len
)
else:
self.config.scale_alpha_value = rope_alpha
# Try to set prompt template
self.prompt_template = await find_prompt_template(
@@ -373,7 +336,8 @@ class ExllamaV2Container(BaseModelContainer):
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = calculate_rope_alpha(
base_seq_len, self.draft_config.max_seq_len
hf_model.hf_config.max_position_embeddings,
self.draft_config.max_seq_len,
)
else:
self.draft_config.scale_alpha_value = draft_rope_alpha
@@ -400,6 +364,59 @@ class ExllamaV2Container(BaseModelContainer):
# Return the created instance
return self
def adjust_cache_size(self, cache_size):
    """
    Round the requested cache size up to a multiple of 256 and return it.

    Side effects: if self.config.max_seq_len exceeds the rounded cache
    size, it is clamped down to that size. Warnings are logged for the
    rounding, for the clamp, and when the cache is smaller than
    2 * max_seq_len (relevant to CFG requests).

    NOTE(review): the comparison uses whatever self.config.max_seq_len
    holds when this is called. The call site visible in this change runs
    this before assigning the user-supplied max_seq_len -- confirm that
    is the intended value to compare against.
    """

    # The cache must be a multiple of 256. Always round up so the cache
    # never ends up smaller than what was asked for.
    full_chunks, leftover = divmod(cache_size, 256)
    if leftover != 0:
        padded_size = int(256 * (full_chunks + 1))

        logger.warning(
            f"The given cache size ({cache_size}) is "
            "not a multiple of 256.\n"
            "Overriding cache_size with an overestimated value of "
            f"{padded_size} tokens."
        )

        cache_size = padded_size

    # A sequence cannot be longer than the cache that backs it
    if cache_size < self.config.max_seq_len:
        logger.warning(
            f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
            f"the cache size and will be limited to {cache_size} tokens."
        )
        self.config.max_seq_len = cache_size

    # Checked after the clamp above, so it sees the final max_seq_len
    if 2 * self.config.max_seq_len > cache_size:
        logger.warning(
            f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
            "and may be too small for requests using CFG. \n"
            "Ignore this warning if you do not plan on using CFG."
        )

    return cache_size
def adjust_max_seq_len(self, max_seq_len):
    """
    Resolve the max_seq_len to use, bounded by self.cache_size.

    Reads self.cache_size, so that attribute must be assigned before
    this is called.

    - Falsy input (None, 0): falls back to the smaller of the model's
      max_position_embeddings and the cache size, and logs a warning.
    - Input larger than the cache size: clamped to the cache size,
      and logs a warning.
    - Otherwise the input is returned unchanged.

    Returns the resolved value; nothing on self is modified here.
    """

    # Nothing was supplied, so derive a default from the model and cache
    if not max_seq_len:
        default_max_seq_len = min(
            self.hf_model.hf_config.max_position_embeddings, self.cache_size
        )

        logger.warning(
            f"max_seq_len is undefined. Overriding to {default_max_seq_len} tokens."
        )

        return default_max_seq_len

    # A sequence cannot be longer than the cache that backs it
    if max_seq_len > self.cache_size:
        logger.warning(
            f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
            f"and will be limited to {self.cache_size} tokens."
        )

        return self.cache_size

    return max_seq_len
def model_info(self):
draft_model_card: ModelCard = None
if self.draft_config: