Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-15 00:07:28 +00:00.
Model: Add override for the base sequence length

Some models (such as Mistral and Mixtral) set their base sequence length to 32k on the assumption that sliding-window attention is supported. This commit adds a parameter to override a model's base sequence length, which helps with auto-calculation of rope alpha. If auto-calculation of rope alpha isn't being used, the existing max_seq_len parameter works fine as is.

Signed-off-by: kingbri <bdashore3@proton.me>
Changed files: model.py (13 lines changed)
@@ -85,14 +85,19 @@ class ModelContainer:

(NOTE: the +/- add/remove markers of this hunk were lost during extraction; the lines below are the hunk content as scraped. Presumably the first max_seq_len override block is the removed old code and the base_seq_len/target blocks are the additions — verify against the upstream commit.)

            self.config.max_seq_len = 4096
        self.config.prepare()

        # Then override the max_seq_len if present
        override_max_seq_len = kwargs.get("max_seq_len")
        if override_max_seq_len:
            self.config.max_seq_len = kwargs.get("max_seq_len")

        # Then override the base_seq_len if present
        override_base_seq_len = kwargs.get("override_base_seq_len")
        if override_base_seq_len:
            self.config.max_seq_len = override_base_seq_len

        # Grab the base model's sequence length before overrides for rope calculations
        base_seq_len = self.config.max_seq_len

        # Set the target seq len if present
        target_max_seq_len = kwargs.get("max_seq_len")
        if target_max_seq_len:
            self.config.max_seq_len = target_max_seq_len

        # Set the rope scale
        self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
Reference in New Issue
Block a user