diff --git a/OAI/utils.py b/OAI/utils.py
index 18e73ec..7421de3 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -77,9 +77,9 @@ def create_chat_completion_stream_chunk(const_id: str,
 
     return chunk
 
-def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
+def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
 
-    # Convert the draft model path to a pathlib path for equality comparisons
+    # Convert the provided draft model path to a pathlib path for equality comparisons
     if draft_model_path:
         draft_model_path = pathlib.Path(draft_model_path).resolve()
 
diff --git a/main.py b/main.py
index 2839282..955a5ff 100644
--- a/main.py
+++ b/main.py
@@ -57,10 +57,8 @@ app.add_middleware(
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
     model_config = unwrap(config.get("model"), {})
-    if "model_dir" in model_config:
-        model_path = pathlib.Path(model_config["model_dir"])
-    else:
-        model_path = pathlib.Path("models")
+    model_dir = unwrap(model_config.get("model_dir"), "models")
+    model_path = pathlib.Path(model_dir)
 
     draft_config = unwrap(model_config.get("draft"), {})
     draft_model_dir = draft_config.get("draft_model_dir")
@@ -102,6 +100,18 @@ async def get_current_model():
 
     return model_card
 
+# Draft model list endpoint (API key required, same as /v1/model/list)
+@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
+async def list_draft_models():
+    model_config = unwrap(config.get("model"), {})
+    draft_config = unwrap(model_config.get("draft"), {})
+    draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
+    draft_model_path = pathlib.Path(draft_model_dir)
+
+    models = get_model_list(draft_model_path.resolve())
+
+    return models
+
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
 async def load_model(request: Request, data: ModelLoadRequest):