From 1d0bdfa77cbccdfbedd4b75d1997d2acdcc0d906 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 17 Dec 2023 14:28:18 -0500 Subject: [PATCH] Model + OAI: Fix parameter parsing Rope alpha changes don't require removing the 1.0 default from Rope scale. Keep defaults when possible to avoid errors. Signed-off-by: kingbri --- OAI/types/model.py | 4 ++-- model.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/OAI/types/model.py b/OAI/types/model.py index ad7410d..dde43fc 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -6,7 +6,7 @@ from gen_logging import LogConfig class ModelCardParameters(BaseModel): max_seq_len: Optional[int] = 4096 rope_scale: Optional[float] = 1.0 - rope_alpha: Optional[float] = 1.0 + rope_alpha: Optional[float] = None prompt_template: Optional[str] = None cache_mode: Optional[str] = "FP16" draft: Optional['ModelCard'] = None @@ -25,8 +25,8 @@ class ModelList(BaseModel): class DraftModelLoadRequest(BaseModel): draft_model_name: str + draft_rope_scale: Optional[float] = 1.0 draft_rope_alpha: Optional[float] = None - draft_rope_scale: Optional[float] = None # TODO: Unify this with ModelCardParams class ModelLoadRequest(BaseModel): diff --git a/model.py b/model.py index 3316db5..6a4598d 100644 --- a/model.py +++ b/model.py @@ -88,7 +88,10 @@ class ModelContainer: self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0) # Automatically calculate rope alpha - self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)) + self.config.scale_alpha_value = unwrap( + kwargs.get("rope_alpha"), + self.calculate_rope_alpha(base_seq_len) + ) # Turn off flash attention? self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False) @@ -124,7 +127,12 @@ class ModelContainer: self.draft_config.prepare() self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0) - self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len)) + + # Automatically calculate draft rope alpha + self.draft_config.scale_alpha_value = unwrap( + draft_args.get("draft_rope_alpha"), + self.calculate_rope_alpha(self.draft_config.max_seq_len) + ) self.draft_config.max_seq_len = self.config.max_seq_len if "chunk_size" in kwargs: