diff --git a/model.py b/model.py index ed8d50f..2c79f5c 100644 --- a/model.py +++ b/model.py @@ -95,7 +95,7 @@ class ModelContainer: if self.draft_enabled: self.draft_config = ExLlamaV2Config() - draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models") + draft_model_path = pathlib.Path(draft_config.get("draft_model_dir") or "models") draft_model_path = draft_model_path / draft_model_name self.draft_config.model_dir = str(draft_model_path.resolve())