diff --git a/model.py b/model.py index 360cf88..af933cb 100644 --- a/model.py +++ b/model.py @@ -110,9 +110,9 @@ class ModelContainer: self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - - self.draft_config.scale_pos_emb = kwargs.get("draft_rope_scale") or 1.0 - self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) + + self.draft_config.scale_pos_emb = draft_config.get("draft_rope_scale") or 1.0 + self.draft_config.scale_alpha_value = draft_config.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) self.draft_config.max_seq_len = self.config.max_seq_len if "chunk_size" in kwargs: