From 1c398b0be7257727409cfedeb4e9ece650836576 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:02:29 -0800 Subject: [PATCH 1/3] Add automatic NTK-aware alpha scaling to model * enables automatic calculation of NTK-aware alpha scaling for models if the rope_alpha arg is not passed in the config, using the same formula used for draft models --- model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/model.py b/model.py index 1eb90ae..b0cf3b8 100644 --- a/model.py +++ b/model.py @@ -69,10 +69,17 @@ class ModelContainer: self.config = ExLlamaV2Config() self.config.model_dir = str(model_directory.resolve()) self.config.prepare() + + base_seq_len = self.config.max_seq_len if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"] if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"] - if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"] + if "rope_alpha" in kwargs: + self.config.scale_alpha_value = kwargs["rope_alpha"] + else: + ratio = self.config.max_seq_len / base_seq_len + alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 + self.config.scale_alpha_value = alpha if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"] if "low_mem" in kwargs and kwargs["low_mem"]: @@ -102,7 +109,7 @@ class ModelContainer: self.draft_config.prepare() if "draft_rope_alpha" in kwargs: - self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1 + self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") else: ratio = self.config.max_seq_len / self.draft_config.max_seq_len alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 From bd2c5d0d097fbccca95f5b10567ba72e18747591 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:19:59 -0800 Subject: [PATCH 2/3] Force auto-alpha to 1.0 if config ctx == base ctx 
--- model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model.py b/model.py index b0cf3b8..564965b 100644 --- a/model.py +++ b/model.py @@ -79,6 +79,7 @@ class ModelContainer: else: ratio = self.config.max_seq_len / base_seq_len alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 + if ratio == 1: alpha = 1.0 self.config.scale_alpha_value = alpha if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"] @@ -113,6 +114,7 @@ class ModelContainer: else: ratio = self.config.max_seq_len / self.draft_config.max_seq_len alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 + if ratio == 1: alpha = 1.0 self.draft_config.scale_alpha_value = alpha self.draft_config.max_seq_len = self.config.max_seq_len From 27fc0c00692544be0ca7e7ae9e0ad4e814d12cf9 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 3 Dec 2023 01:05:09 -0500 Subject: [PATCH 3/3] Model: Cleanup and compartmentalize auto rope functions Also handle an edge case if ratio <= 1 since NTK scaling is only used for values > 1. 
Signed-off-by: kingbri --- model.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/model.py b/model.py index 564965b..25bfb25 100644 --- a/model.py +++ b/model.py @@ -69,20 +69,18 @@ class ModelContainer: self.config = ExLlamaV2Config() self.config.model_dir = str(model_directory.resolve()) self.config.prepare() - + + # Grab the base model's sequence length before overrides for rope calculations base_seq_len = self.config.max_seq_len + # Then override the max_seq_len if present if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"] if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"] - if "rope_alpha" in kwargs: - self.config.scale_alpha_value = kwargs["rope_alpha"] - else: - ratio = self.config.max_seq_len / base_seq_len - alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 - if ratio == 1: alpha = 1.0 - self.config.scale_alpha_value = alpha - if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"] + # Automatically calculate rope alpha + self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len) + + if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"] if "low_mem" in kwargs and kwargs["low_mem"]: self.config.set_low_mem() @@ -109,14 +107,7 @@ class ModelContainer: self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - if "draft_rope_alpha" in kwargs: - self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") - else: - ratio = self.config.max_seq_len / self.draft_config.max_seq_len - alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 - if ratio == 1: alpha = 1.0 - self.draft_config.scale_alpha_value = alpha - + self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) self.draft_config.max_seq_len = self.config.max_seq_len if "chunk_size" in 
kwargs: @@ -124,6 +115,13 @@ class ModelContainer: self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2 + def calculate_rope_alpha(self, base_seq_len): + ratio = self.config.max_seq_len / base_seq_len + + # Default to an alpha of 1 if the ratio is less than or equal to 1 + alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 + return alpha + def get_model_path(self): model_path = pathlib.Path(self.config.model_dir) return model_path