diff --git a/main.py b/main.py index b04d684..1948373 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import uvicorn import yaml -import pathlib +import pathlib, os from auth import check_admin_key, check_api_key, load_auth_keys from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware @@ -75,6 +75,8 @@ async def get_current_model(): # Load model endpoint @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) async def load_model(data: ModelLoadRequest): + global model_container + if model_container and model_container.model: raise HTTPException(400, "A model is already loaded! Please unload it first.") @@ -94,13 +96,15 @@ async def load_model(data: ModelLoadRequest): load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models" + if not model_path.exists(): + raise HTTPException(400, "model_path does not exist. Check model_name?") + + model_container = ModelContainer(model_path.resolve(), False, **load_data) + def generator(): - global model_container - - model_container = ModelContainer(model_path.resolve(), False, **load_data) model_type = "draft" if model_container.draft_enabled else "model" - load_status = model_container.load_gen(load_progress) + for (module, modules) in load_status: if module == 0: loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)