mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API: Fix model load endpoint with draft
The draft model config wasn't being parsed correctly after the recent changes that removed the draft_enabled bool. There is still more work to be done on returning exceptions from the load endpoint.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -15,6 +15,7 @@ class ModelList(BaseModel):
|
||||
class DraftModelLoadRequest(BaseModel):
|
||||
draft_model_name: str
|
||||
draft_rope_alpha: float = 1.0
|
||||
draft_rope_scale: float = 1.0
|
||||
|
||||
class ModelLoadRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
14
main.py
14
main.py
@@ -87,18 +87,20 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
# TODO: Move this to model_container
|
||||
model_config = config.get("model") or {}
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.dict()
|
||||
if data.draft and "draft" in model_config:
|
||||
draft_config = model_config.get("draft") or {}
|
||||
|
||||
# TODO: Add API exception if draft directory isn't found
|
||||
draft_config = model_config.get("draft") or {}
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
|
||||
|
||||
load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
|
||||
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
|
||||
|
||||
if not model_path.exists():
|
||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||
@@ -108,7 +110,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||
async def generator():
|
||||
global model_container
|
||||
|
||||
model_type = "draft" if model_container.draft_enabled else "model"
|
||||
model_type = "draft" if model_container.draft_config else "model"
|
||||
load_status = model_container.load_gen(load_progress)
|
||||
|
||||
try:
|
||||
@@ -131,7 +133,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||
|
||||
yield get_sse_packet(response.json(ensure_ascii=False))
|
||||
|
||||
if model_container.draft_enabled:
|
||||
# Switch to model progress if the draft model is loaded
|
||||
if model_container.draft_config:
|
||||
model_type = "model"
|
||||
else:
|
||||
loading_bar.next()
|
||||
@@ -290,6 +293,7 @@ if __name__ == "__main__":
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
if "model_name" in model_config:
|
||||
# TODO: Move this to model_container
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = model_path / model_config.get("model_name")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user