Model + OAI: Fix parameter parsing

Rope alpha changes don't require removing the 1.0 default
from Rope scale.

Keep defaults when possible to avoid errors.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 14:28:18 -05:00
parent 3e57125025
commit 1d0bdfa77c
2 changed files with 12 additions and 4 deletions

View File

@@ -6,7 +6,7 @@ from gen_logging import LogConfig
class ModelCardParameters(BaseModel): class ModelCardParameters(BaseModel):
max_seq_len: Optional[int] = 4096 max_seq_len: Optional[int] = 4096
rope_scale: Optional[float] = 1.0 rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0 rope_alpha: Optional[float] = None
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
cache_mode: Optional[str] = "FP16" cache_mode: Optional[str] = "FP16"
draft: Optional['ModelCard'] = None draft: Optional['ModelCard'] = None
@@ -25,8 +25,8 @@ class ModelList(BaseModel):
class DraftModelLoadRequest(BaseModel): class DraftModelLoadRequest(BaseModel):
draft_model_name: str draft_model_name: str
draft_rope_scale: Optional[float] = 1.0
draft_rope_alpha: Optional[float] = None draft_rope_alpha: Optional[float] = None
draft_rope_scale: Optional[float] = None
# TODO: Unify this with ModelCardParams # TODO: Unify this with ModelCardParams
class ModelLoadRequest(BaseModel): class ModelLoadRequest(BaseModel):

View File

@@ -88,7 +88,10 @@ class ModelContainer:
self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0) self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
# Automatically calculate rope alpha # Automatically calculate rope alpha
self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)) self.config.scale_alpha_value = unwrap(
kwargs.get("rope_alpha"),
self.calculate_rope_alpha(base_seq_len)
)
# Turn off flash attention? # Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False) self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)
@@ -124,7 +127,12 @@ class ModelContainer:
self.draft_config.prepare() self.draft_config.prepare()
self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0) self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
# Automatically calculate draft rope alpha
self.draft_config.scale_alpha_value = unwrap(
draft_args.get("draft_rope_alpha"),
self.calculate_rope_alpha(self.draft_config.max_seq_len)
)
self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.max_seq_len = self.config.max_seq_len
if "chunk_size" in kwargs: if "chunk_size" in kwargs: