API: Fix model load endpoint with draft

Draft wasn't being parsed correctly with the new changes which removed
the draft_enabled bool. There's still some more work to be done with
returning exceptions.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-06 18:05:55 -05:00
parent 6a71890d45
commit f8e9e22c43
2 changed files with 10 additions and 5 deletions

View File

@@ -15,6 +15,7 @@ class ModelList(BaseModel):
class DraftModelLoadRequest(BaseModel):
draft_model_name: str
draft_rope_alpha: float = 1.0
draft_rope_scale: float = 1.0
class ModelLoadRequest(BaseModel):
name: str

14
main.py
View File

@@ -87,18 +87,20 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")
# TODO: Move this to model_container
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / data.name
load_data = data.dict()
if data.draft and "draft" in model_config:
draft_config = model_config.get("draft") or {}
# TODO: Add API exception if draft directory isn't found
draft_config = model_config.get("draft") or {}
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
@@ -108,7 +110,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
async def generator():
global model_container
model_type = "draft" if model_container.draft_enabled else "model"
model_type = "draft" if model_container.draft_config else "model"
load_status = model_container.load_gen(load_progress)
try:
@@ -131,7 +133,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
yield get_sse_packet(response.json(ensure_ascii=False))
if model_container.draft_enabled:
# Switch to model progress if the draft model is loaded
if model_container.draft_config:
model_type = "model"
else:
loading_bar.next()
@@ -290,6 +293,7 @@ if __name__ == "__main__":
model_config = config.get("model") or {}
if "model_name" in model_config:
# TODO: Move this to model_container
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / model_config.get("model_name")