From fdb86f4c63003a1493289179b464d93a9e198950 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Tue, 14 Oct 2025 23:02:52 -0400
Subject: [PATCH] ExllamaV2: Add max_seq_len empty case like ExllamaV3

Also remove the intermediate base_seq_len and target_seq_len variables
to make the code clearer. If paged mode is off, max_seq_len becomes the
prime mover, since batching is unavailable.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 backends/exllamav2/model.py | 145 ++++++++++++++++++++----------------
 1 file changed, 81 insertions(+), 64 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 1d32309..ece7083 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -233,33 +233,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Hardcode max output length to 16
         self.config.max_output_len = 16
 
-        # Grab the base model's sequence length before overrides for
-        # rope calculations
-        base_seq_len = hf_model.hf_config.max_position_embeddings
-
-        # Set the target seq len if present
-        target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
-
-        # Set the rope scale
-        self.config.scale_pos_emb = unwrap(
-            kwargs.get("rope_scale"), self.config.scale_pos_emb
-        )
-
-        # Sets rope alpha value.
-        # Utilize the model's max_position_embeddings as a base value
-        # Automatically calculate if unset or defined as an "auto" literal.
-        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
-        if rope_alpha == "auto":
-            self.config.scale_alpha_value = calculate_rope_alpha(
-                base_seq_len, target_seq_len
-            )
-        else:
-            self.config.scale_alpha_value = rope_alpha
-
-        # Set the max seq len if specified
-        if target_seq_len:
-            self.config.max_seq_len = target_seq_len
-
         # Set max batch size to the config override
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
 
@@ -286,46 +259,36 @@ class ExllamaV2Container(BaseModelContainer):
             self.max_batch_size = 1
             torch.backends.cuda.enable_flash_sdp(False)
 
+        # Grab user-set max seq len
+        user_max_seq_len = kwargs.get("max_seq_len")
+
         # Set k/v cache size
         # cache_size is only relevant when paged mode is enabled
         if self.paged:
-            cache_size = unwrap(kwargs.get("cache_size"), 4096)
-
-            # Enforce a multiple of 256 for cache size
-            # Overestimate to ensure that the cache isn't below max_seq_len
-            cache_remainder = cache_size % 256
-            if cache_remainder != 0:
-                rounded_cache_size = int(
-                    256 * ((cache_size - cache_remainder) / 256 + 1)
-                )
-
-                logger.warning(
-                    f"The given cache size ({cache_size}) is "
-                    "not a multiple of 256.\n"
-                    "Overriding cache_size with an overestimated value of "
-                    f"{rounded_cache_size} tokens."
-                )
-
-                cache_size = rounded_cache_size
-
-            if self.config.max_seq_len > cache_size:
-                logger.warning(
-                    f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
-                    f"the cache size and will be limited to {cache_size} tokens."
-                )
-                self.config.max_seq_len = cache_size
-
-            # Warn user if cache size may be inadequate for CFG
-            if cache_size < 2 * self.config.max_seq_len:
-                logger.warning(
-                    f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
-                    "and may be too small for requests using CFG. \n"
-                    "Ignore this warning if you do not plan on using CFG."
-                )
-
-            self.cache_size = cache_size
+            user_cache_size = coalesce(kwargs.get("cache_size"), user_max_seq_len, 4096)
+            self.cache_size = self.adjust_cache_size(user_cache_size)
+            self.config.max_seq_len = self.adjust_max_seq_len(user_max_seq_len)
         else:
-            self.cache_size = self.config.max_seq_len
+            self.config.max_seq_len = unwrap(
+                user_max_seq_len,
+                min(hf_model.hf_config.max_position_embeddings, 4096),
+            )
+
+        # Set the rope scale
+        self.config.scale_pos_emb = unwrap(
+            kwargs.get("rope_scale"), self.config.scale_pos_emb
+        )
+
+        # Set the rope alpha value.
+        # Utilize the model's max_position_embeddings as a base value
+        # Automatically calculate if unset or defined as an "auto" literal.
+        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
+        if rope_alpha == "auto":
+            self.config.scale_alpha_value = calculate_rope_alpha(
+                hf_model.hf_config.max_position_embeddings, self.config.max_seq_len
+            )
+        else:
+            self.config.scale_alpha_value = rope_alpha
 
         # Try to set prompt template
         self.prompt_template = await find_prompt_template(
@@ -373,7 +336,8 @@ class ExllamaV2Container(BaseModelContainer):
         draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
         if draft_rope_alpha == "auto":
             self.draft_config.scale_alpha_value = calculate_rope_alpha(
-                base_seq_len, self.draft_config.max_seq_len
+                hf_model.hf_config.max_position_embeddings,
+                self.draft_config.max_seq_len,
             )
         else:
             self.draft_config.scale_alpha_value = draft_rope_alpha
@@ -400,6 +364,59 @@ class ExllamaV2Container(BaseModelContainer):
         # Return the created instance
         return self
 
+    def adjust_cache_size(self, cache_size):
+        # Enforce a multiple of 256 for cache size
+        # Overestimate to ensure that the cache isn't below max_seq_len
+        cache_remainder = cache_size % 256
+        if cache_remainder != 0:
+            rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
+
+            logger.warning(
+                f"The given cache size ({cache_size}) is "
+                "not a multiple of 256.\n"
+                "Overriding cache_size with an overestimated value of "
+                f"{rounded_cache_size} tokens."
+            )
+
+            cache_size = rounded_cache_size
+
+        if self.config.max_seq_len > cache_size:
+            logger.warning(
+                f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
+                f"the cache size and will be limited to {cache_size} tokens."
+            )
+            self.config.max_seq_len = cache_size
+
+        # Warn user if cache size may be inadequate for CFG
+        if cache_size < 2 * self.config.max_seq_len:
+            logger.warning(
+                f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
+                "and may be too small for requests using CFG.\n"
+                "Ignore this warning if you do not plan on using CFG."
+            )
+
+        return cache_size
+
+    def adjust_max_seq_len(self, max_seq_len):
+        # Fall back to the model's native context length, bounded by cache size
+        if not max_seq_len:
+            default_max_seq_len = min(
+                self.hf_model.hf_config.max_position_embeddings, self.cache_size
+            )
+
+            logger.warning(
+                f"max_seq_len is undefined. Overriding to {default_max_seq_len} tokens."
+            )
+            max_seq_len = default_max_seq_len
+        elif max_seq_len > self.cache_size:
+            logger.warning(
+                f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
+                f"and will be limited to {self.cache_size} tokens."
+            )
+            max_seq_len = self.cache_size
+
+        return max_seq_len
+
     def model_info(self):
         draft_model_card: ModelCard = None
         if self.draft_config:
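
Reviewer note: the sketch below is a minimal, runnable distillation of the paged-mode sizing rules this patch introduces. The helpers round_cache_size and resolve_max_seq_len are hypothetical stand-ins for adjust_cache_size and adjust_max_seq_len (logging and container state omitted); the numeric values are arbitrary.

# Sketch of the patch's sizing rules (hypothetical helpers, not the
# real container methods). Assumes paged mode is enabled.

def round_cache_size(cache_size: int) -> int:
    # Round up to the next multiple of 256 so the cache never ends up
    # smaller than requested (mirrors adjust_cache_size)
    remainder = cache_size % 256
    if remainder != 0:
        cache_size = int(256 * ((cache_size - remainder) / 256 + 1))
    return cache_size

def resolve_max_seq_len(user_max_seq_len, max_position_embeddings, cache_size):
    # Empty case: fall back to the model's native context length,
    # bounded by the cache size. Set case: clamp to the cache size
    # (mirrors adjust_max_seq_len)
    if not user_max_seq_len:
        return min(max_position_embeddings, cache_size)
    return min(user_max_seq_len, cache_size)

assert round_cache_size(4096) == 4096  # already a multiple of 256
assert round_cache_size(5000) == 5120  # overestimated, never rounded down
assert resolve_max_seq_len(None, 32768, 8192) == 8192   # empty case: bounded default
assert resolve_max_seq_len(16384, 32768, 8192) == 8192  # clamped to cache size
assert resolve_max_seq_len(2048, 32768, 8192) == 2048   # user value respected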
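
A similar sketch of the fallback chains, assuming coalesce (and, for a single fallback, unwrap) returns the first non-None argument; this stand-in matches how the patch chains cache_size -> user max_seq_len -> 4096 when paged, and resolves max_seq_len directly when paged mode is off.

# Stand-in for the project's utility helper
def coalesce(*args):
    return next((a for a in args if a is not None), None)

# Paged: cache size falls back to the user's max_seq_len, then 4096
assert coalesce(None, 12288, 4096) == 12288
assert coalesce(None, None, 4096) == 4096

# Non-paged: max_seq_len is the prime mover; when unset, the default is
# the smaller of the model's native context length and 4096
max_position_embeddings = 32768
user_max_seq_len = None
max_seq_len = coalesce(user_max_seq_len, min(max_position_embeddings, 4096))
assert max_seq_len == 4096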