Config: Use an explicit "auto" value for rope_alpha

Using an explicit "auto" value for rope alpha removes ambiguity about how to
enable automatic rope alpha calculation. The existing behavior (None ->
automatic calculation) is preserved, but it can be overridden if a model's
tabby_config.yml includes `rope_alpha`.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-08-30 12:45:09 -04:00
committed by Brian Dashore
parent a96fa5f138
commit 4aebe8a2a5
5 changed files with 50 additions and 17 deletions

View File

@@ -249,10 +249,13 @@ class ExllamaV2Container:
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
# Automatically calculate rope alpha
self.config.scale_alpha_value = unwrap(
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)
# Sets rope alpha value.
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
else:
self.config.scale_alpha_value = rope_alpha
# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
@@ -344,16 +347,22 @@ class ExllamaV2Container:
# Set user-configured draft model values
if enable_draft:
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)
# Automatically calculate draft rope alpha
self.draft_config.scale_alpha_value = unwrap(
draft_args.get("draft_rope_alpha"),
self.calculate_rope_alpha(self.draft_config.max_seq_len),
)
self.draft_config.max_seq_len = self.config.max_seq_len
# Set draft rope alpha. Follows same behavior as model rope alpha.
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
self.draft_config.max_seq_len
)
else:
self.draft_config.scale_alpha_value = draft_rope_alpha
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
if chunk_size: