Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-15 00:07:28 +00:00
ExllamaV2: Add max_seq_len empty case like ExllamaV3
Also remove the intermediate base_seq_len and target_seq_len variables to make the code clearer. If paged mode is off, max_seq_len becomes the prime mover, since batching is unavailable.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
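As a rough illustration of the fallback order the diff below implements (a minimal sketch, not tabbyAPI code: the standalone resolve_max_seq_len helper is hypothetical, plain `or` stands in for the repo's unwrap/coalesce helpers, and the warnings plus the multiple-of-256 cache rounding are omitted):

def resolve_max_seq_len(user_max_seq_len, user_cache_size,
                        max_position_embeddings, paged):
    if paged:
        # cache_size falls back to the user's max_seq_len, then to 4096
        cache_size = user_cache_size or user_max_seq_len or 4096
        # an unset max_seq_len falls back to the model's native length;
        # either way it is capped at the cache size
        return min(user_max_seq_len or max_position_embeddings, cache_size)

    # paged mode off: max_seq_len is the prime mover, defaulting to
    # min(native length, 4096) when the user leaves it empty
    return user_max_seq_len or min(max_position_embeddings, 4096)

For example, with max_seq_len and cache_size left empty on a model whose config reports 131072 positions, both the paged and non-paged paths settle on 4096 tokens.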
@@ -233,33 +233,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Hardcode max output length to 16
         self.config.max_output_len = 16
 
-        # Grab the base model's sequence length before overrides for
-        # rope calculations
-        base_seq_len = hf_model.hf_config.max_position_embeddings
-
-        # Set the target seq len if present
-        target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
-
-        # Set the rope scale
-        self.config.scale_pos_emb = unwrap(
-            kwargs.get("rope_scale"), self.config.scale_pos_emb
-        )
-
-        # Sets rope alpha value.
-        # Utilize the model's max_position_embeddings as a base value
-        # Automatically calculate if unset or defined as an "auto" literal.
-        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
-        if rope_alpha == "auto":
-            self.config.scale_alpha_value = calculate_rope_alpha(
-                base_seq_len, target_seq_len
-            )
-        else:
-            self.config.scale_alpha_value = rope_alpha
-
-        # Set the max seq len if specified
-        if target_seq_len:
-            self.config.max_seq_len = target_seq_len
-
         # Set max batch size to the config override
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
 
@@ -286,46 +259,36 @@ class ExllamaV2Container(BaseModelContainer):
             self.max_batch_size = 1
             torch.backends.cuda.enable_flash_sdp(False)
 
+        # Grab user-set max seq len
+        user_max_seq_len = kwargs.get("max_seq_len")
+
         # Set k/v cache size
         # cache_size is only relevant when paged mode is enabled
         if self.paged:
-            cache_size = unwrap(kwargs.get("cache_size"), 4096)
-
-            # Enforce a multiple of 256 for cache size
-            # Overestimate to ensure that the cache isn't below max_seq_len
-            cache_remainder = cache_size % 256
-            if cache_remainder != 0:
-                rounded_cache_size = int(
-                    256 * ((cache_size - cache_remainder) / 256 + 1)
-                )
-
-                logger.warning(
-                    f"The given cache size ({cache_size}) is "
-                    "not a multiple of 256.\n"
-                    "Overriding cache_size with an overestimated value of "
-                    f"{rounded_cache_size} tokens."
-                )
-
-                cache_size = rounded_cache_size
-
-            if self.config.max_seq_len > cache_size:
-                logger.warning(
-                    f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
-                    f"the cache size and will be limited to {cache_size} tokens."
-                )
-                self.config.max_seq_len = cache_size
-
-            # Warn user if cache size may be inadequate for CFG
-            if cache_size < 2 * self.config.max_seq_len:
-                logger.warning(
-                    f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
-                    "and may be too small for requests using CFG. \n"
-                    "Ignore this warning if you do not plan on using CFG."
-                )
-
-            self.cache_size = cache_size
+            user_cache_size = coalesce(kwargs.get("cache_size"), user_max_seq_len, 4096)
+            self.cache_size = self.adjust_cache_size(user_cache_size)
+            self.config.max_seq_len = self.adjust_max_seq_len(user_max_seq_len)
         else:
-            self.cache_size = self.config.max_seq_len
+            self.config.max_seq_len = unwrap(
+                user_max_seq_len,
+                min(hf_model.hf_config.max_position_embeddings, 4096)
+            )
 
+        # Set the rope scale
+        self.config.scale_pos_emb = unwrap(
+            kwargs.get("rope_scale"), self.config.scale_pos_emb
+        )
+
+        # Sets rope alpha value.
+        # Utilize the model's max_position_embeddings as a base value
+        # Automatically calculate if unset or defined as an "auto" literal.
+        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
+        if rope_alpha == "auto":
+            self.config.scale_alpha_value = calculate_rope_alpha(
+                hf_model.hf_config.max_position_embeddings, self.config.max_seq_len
+            )
+        else:
+            self.config.scale_alpha_value = rope_alpha
+
         # Try to set prompt template
         self.prompt_template = await find_prompt_template(
@@ -373,7 +336,8 @@ class ExllamaV2Container(BaseModelContainer):
         draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
         if draft_rope_alpha == "auto":
             self.draft_config.scale_alpha_value = calculate_rope_alpha(
-                base_seq_len, self.draft_config.max_seq_len
+                hf_model.hf_config.max_position_embeddings,
+                self.draft_config.max_seq_len,
             )
         else:
             self.draft_config.scale_alpha_value = draft_rope_alpha
@@ -400,6 +364,59 @@ class ExllamaV2Container(BaseModelContainer):
         # Return the created instance
         return self
 
+    def adjust_cache_size(self, cache_size):
+        # Enforce a multiple of 256 for cache size
+        # Overestimate to ensure that the cache isn't below max_seq_len
+        cache_remainder = cache_size % 256
+        if cache_remainder != 0:
+            rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
+
+            logger.warning(
+                f"The given cache size ({cache_size}) is "
+                "not a multiple of 256.\n"
+                "Overriding cache_size with an overestimated value of "
+                f"{rounded_cache_size} tokens."
+            )
+
+            cache_size = rounded_cache_size
+
+        if self.config.max_seq_len > cache_size:
+            logger.warning(
+                f"The given max_seq_len ({self.config.max_seq_len}) is larger than "
+                f"the cache size and will be limited to {cache_size} tokens."
+            )
+            self.config.max_seq_len = cache_size
+
+        # Warn user if cache size may be inadequate for CFG
+        if cache_size < 2 * self.config.max_seq_len:
+            logger.warning(
+                f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
+                "and may be too small for requests using CFG. \n"
+                "Ignore this warning if you do not plan on using CFG."
+            )
+
+        return cache_size
+
+    def adjust_max_seq_len(self, max_seq_len):
+        print(f"User max seq len {max_seq_len}")
+        if not max_seq_len:
+            default_max_seq_len = min(
+                self.hf_model.hf_config.max_position_embeddings, self.cache_size
+            )
+
+            logger.warning(
+                f"max_seq_len is undefined. Overriding to {default_max_seq_len} tokens."
+            )
+            max_seq_len = default_max_seq_len
+        elif max_seq_len > self.cache_size:
+            logger.warning(
+                f"The given max_seq_len ({max_seq_len}) is larger than the cache size "
+                f"and will be limited to {self.cache_size} tokens."
+            )
+            max_seq_len = self.cache_size
+
+        return max_seq_len
+
     def model_info(self):
         draft_model_card: ModelCard = None
         if self.draft_config:
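As a quick sanity check of the rounding arithmetic in adjust_cache_size above (again just an illustrative snippet, not part of the commit):

cache_size = 5000
cache_remainder = cache_size % 256  # 136
if cache_remainder != 0:
    # overestimate to the next multiple of 256
    cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
print(cache_size)  # 5120; a value already divisible by 256 is left unchanged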