OAI: Add API-based model loading/unloading and auth routes

Models can be loaded and unloaded via the API. Also add authentication
to use the API and for administrator tasks.

Both types of authorization use different keys.

Also fix the unload function to properly free all used vram.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-14 01:17:19 -05:00
parent 47343e2f1a
commit b625bface9
11 changed files with 195 additions and 55 deletions

107
main.py
View File

@@ -1,30 +1,86 @@
import uvicorn
import yaml
from fastapi import FastAPI, Request
import pathlib
from auth import check_admin_key, check_api_key, load_auth_keys
from fastapi import FastAPI, Request, HTTPException, Depends
from model import ModelContainer
from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse
from OAI.models.completions import CompletionRequest, CompletionResponse
from OAI.models.models import ModelCard, ModelList
from OAI.types.completions import CompletionRequest, CompletionResponse
from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse
from OAI.utils import create_completion_response, get_model_list
from typing import Optional
from utils import load_progress
app = FastAPI()
# Initialize a model container. This can be undefined at any period of time
model_container: ModelContainer = None
# Globally scoped variables. Undefined until initalized in main
model_container: Optional[ModelContainer] = None
config: Optional[dict] = None
@app.get("/v1/models")
@app.get("/v1/model/list")
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
models = get_model_list(model_container.get_model_path())
model_config = config["model"]
models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
return models.model_dump_json()
@app.get("/v1/model")
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
async def get_current_model():
return ModelCard(id = model_container.get_model_path().name)
if model_container is None or model_container.model is None:
return HTTPException(400, "No models are loaded.")
@app.post("/v1/completions", response_class=CompletionResponse)
model_card = ModelCard(id=model_container.get_model_path().name)
return model_card.model_dump_json()
@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
def generator():
global model_container
model_config = config["model"]
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_path = model_path / data.name
model_container = ModelContainer(model_path, False, **data.model_dump())
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
else:
loading_bar.next()
yield ModelLoadResponse(
module=module,
modules=modules,
status="processing"
).model_dump_json()
yield ModelLoadResponse(
module=module,
modules=modules,
status="finished"
).model_dump_json()
return EventSourceResponse(generator())
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
async def unload_model():
global model_container
if model_container is None:
raise HTTPException(400, "No models are loaded.")
model_container.unload()
model_container = None
@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest):
if data.stream:
async def generator():
@@ -44,31 +100,32 @@ async def generate_completion(request: Request, data: CompletionRequest):
return response.model_dump_json()
# Wrapper callback for load progress
def load_progress(module, modules):
yield module, modules
if __name__ == "__main__":
# Initialize auth keys
load_auth_keys()
# Load from YAML config. Possibly add a config -> kwargs conversion function
with open('config.yml', 'r') as config_file:
config = yaml.safe_load(config_file)
# If an initial model name is specified, create a container and load the model
if config["model_name"]:
model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}"
model_container = ModelContainer(model_path, False, **config)
model_config = config["model"]
if model_config["model_name"]:
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_path = model_path / model_config["model_name"]
model_container = ModelContainer(model_path, False, **model_config)
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
else:
loading_bar.next()
if module == modules:
loading_bar.finish()
print("Model successfully loaded.")
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
network_config = config["network"]
uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")