Add fallback to draft_rope_scale to 1.0

This commit is contained in:
DocShotgun
2023-12-05 18:51:36 -08:00
committed by GitHub
parent 39f7a2aabd
commit 3f2fcbcc45

View File

@@ -111,7 +111,7 @@ class ModelContainer:
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
if "draft_rope_scale" in kwargs: self.draft_config.scale_pos_emb = kwargs["draft_rope_scale"]
self.draft_config.scale_pos_emb = kwargs.get("draft_rope_scale") or 1.0
self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.draft_config.max_seq_len = self.config.max_seq_len