From f8e9e22c435ed0ae7d3f5d3a390a662d51e470f9 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 6 Dec 2023 18:05:55 -0500 Subject: [PATCH] API: Fix model load endpoint with draft Draft wasn't being parsed correctly with the new changes which removed the draft_enabled bool. There's still some more work to be done with returning exceptions. Signed-off-by: kingbri --- OAI/types/model.py | 1 + main.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/OAI/types/model.py b/OAI/types/model.py index f82daf2..365d8f9 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -15,6 +15,7 @@ class ModelList(BaseModel): class DraftModelLoadRequest(BaseModel): draft_model_name: str draft_rope_alpha: float = 1.0 + draft_rope_scale: float = 1.0 class ModelLoadRequest(BaseModel): name: str diff --git a/main.py b/main.py index e5b9eb1..c9ddc5b 100644 --- a/main.py +++ b/main.py @@ -87,18 +87,20 @@ async def load_model(request: Request, data: ModelLoadRequest): if not data.name: raise HTTPException(400, "model_name not found.") + # TODO: Move this to model_container model_config = config.get("model") or {} model_path = pathlib.Path(model_config.get("model_dir") or "models") model_path = model_path / data.name load_data = data.dict() - if data.draft and "draft" in model_config: - draft_config = model_config.get("draft") or {} + # TODO: Add API exception if draft directory isn't found + draft_config = model_config.get("draft") or {} + if data.draft: if not data.draft.draft_model_name: raise HTTPException(400, "draft_model_name was not found inside the draft object.") - load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models" + load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models" if not model_path.exists(): raise HTTPException(400, "model_path does not exist. Check model_name?") @@ -108,7 +110,7 @@ async def load_model(request: Request, data: ModelLoadRequest): async def generator(): global model_container - model_type = "draft" if model_container.draft_enabled else "model" + model_type = "draft" if model_container.draft_config else "model" load_status = model_container.load_gen(load_progress) try: @@ -131,7 +133,8 @@ async def load_model(request: Request, data: ModelLoadRequest): yield get_sse_packet(response.json(ensure_ascii=False)) - if model_container.draft_enabled: + # Switch to model progress if the draft model is loaded + if model_container.draft_config: model_type = "model" else: loading_bar.next() @@ -290,6 +293,7 @@ if __name__ == "__main__": model_config = config.get("model") or {} if "model_name" in model_config: + # TODO: Move this to model_container model_path = pathlib.Path(model_config.get("model_dir") or "models") model_path = model_path / model_config.get("model_name")