Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
Merge pull request #18 from DocShotgun/main
Add automatic NTK-aware alpha scaling to model
model.py (25 changed lines)
@@ -70,11 +70,17 @@ class ModelContainer:
         self.config.model_dir = str(model_directory.resolve())
         self.config.prepare()
 
+        # Grab the base model's sequence length before overrides for rope calculations
+        base_seq_len = self.config.max_seq_len
+
+        # Then override the max_seq_len if present
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
         if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
-        if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
-        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
+
+        # Automatically calculate rope alpha
+        self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+
+        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
         if "low_mem" in kwargs and kwargs["low_mem"]:
             self.config.set_low_mem()
 
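For reference, a minimal standalone sketch of the override logic in this hunk. The function name resolve_rope_alpha and the literal values are illustrative, not part of the commit. One behavior worth noting: kwargs.get("rope_alpha") or ... treats any falsy override (0, None) as unset and falls through to the automatic calculation.

def resolve_rope_alpha(kwargs: dict, base_seq_len: int, max_seq_len: int) -> float:
    # Same NTK-aware curve as calculate_rope_alpha in the diff,
    # clamped to 1 when the context is not being stretched
    ratio = max_seq_len / base_seq_len
    auto_alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
    return kwargs.get("rope_alpha") or auto_alpha

# An explicit override wins; otherwise the automatic estimate applies
assert resolve_rope_alpha({"rope_alpha": 2.5}, 4096, 8192) == 2.5
print(resolve_rope_alpha({}, 4096, 8192))  # ratio 2.0 -> ~2.63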
@@ -101,13 +107,7 @@ class ModelContainer:
         self.draft_config.model_dir = str(draft_model_path.resolve())
         self.draft_config.prepare()
 
-        if "draft_rope_alpha" in kwargs:
-            self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
-        else:
-            ratio = self.config.max_seq_len / self.draft_config.max_seq_len
-            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
-            self.draft_config.scale_alpha_value = alpha
-
+        self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
         self.draft_config.max_seq_len = self.config.max_seq_len
 
         if "chunk_size" in kwargs:
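A worked example of the draft-model path above, with hypothetical context lengths: the draft model's own native max_seq_len is passed as the base, so the ratio is the main model's sequence length over the draft model's native one (after which the draft's max_seq_len is overwritten to match the main model).

# Hypothetical values: 2048-token draft model serving an 8192-token main model
ratio = 8192 / 2048  # 4.0
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
print(round(alpha, 2))  # 7.7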
@@ -115,6 +115,13 @@ class ModelContainer:
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
+    def calculate_rope_alpha(self, base_seq_len):
+        ratio = self.config.max_seq_len / base_seq_len
+
+        # Default to an alpha of 1 if the context ratio is less than or equal to 1
+        alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
+        return alpha
+
     def get_model_path(self):
         model_path = pathlib.Path(self.config.model_dir)
         return model_path
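To sanity-check the new helper, here is a standalone version of its formula over a few target lengths. The free-function form ntk_alpha and the assumed base context of 4096 are illustrative only.

def ntk_alpha(max_seq_len: int, base_seq_len: int) -> float:
    ratio = max_seq_len / base_seq_len
    # Ratios at or below 1 mean no context extension, so alpha stays 1
    return 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

for target in (2048, 4096, 8192, 16384):
    print(target, round(ntk_alpha(target, 4096), 2))
# 2048 1      (ratio 0.5, clamped)
# 4096 1      (ratio 1.0, clamped)
# 8192 2.63
# 16384 7.7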