mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
OAI: Add API-based model loading/unloading and auth routes
Models can be loaded and unloaded via the API. Also add authentication for using the API and for administrator tasks; the two types of authorization use different keys. Also fix the unload function to properly free all used VRAM. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
107
main.py
107
main.py
@@ -1,30 +1,86 @@
|
||||
import uvicorn
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request
|
||||
import pathlib
|
||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
||||
from model import ModelContainer
|
||||
from progress.bar import IncrementalBar
|
||||
from sse_starlette import EventSourceResponse
|
||||
from OAI.models.completions import CompletionRequest, CompletionResponse
|
||||
from OAI.models.models import ModelCard, ModelList
|
||||
from OAI.types.completions import CompletionRequest, CompletionResponse
|
||||
from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse
|
||||
from OAI.utils import create_completion_response, get_model_list
|
||||
from typing import Optional
|
||||
from utils import load_progress
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize a model container. This can be undefined at any period of time
|
||||
model_container: ModelContainer = None
|
||||
# Globally scoped variables. Undefined until initialized in main
|
||||
model_container: Optional[ModelContainer] = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
@app.get("/v1/models")
|
||||
@app.get("/v1/model/list")
|
||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
models = get_model_list(model_container.get_model_path())
|
||||
model_config = config["model"]
|
||||
models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
|
||||
|
||||
return models.model_dump_json()
|
||||
|
||||
@app.get("/v1/model")
|
||||
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
|
||||
async def get_current_model():
|
||||
return ModelCard(id = model_container.get_model_path().name)
|
||||
if model_container is None or model_container.model is None:
|
||||
return HTTPException(400, "No models are loaded.")
|
||||
|
||||
@app.post("/v1/completions", response_class=CompletionResponse)
|
||||
model_card = ModelCard(id=model_container.get_model_path().name)
|
||||
return model_card.model_dump_json()
|
||||
|
||||
@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)])
|
||||
async def load_model(data: ModelLoadRequest):
|
||||
if model_container and model_container.model:
|
||||
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
||||
|
||||
def generator():
|
||||
global model_container
|
||||
model_config = config["model"]
|
||||
model_path = pathlib.Path(model_config["model_dir"] or "models")
|
||||
model_path = model_path / data.name
|
||||
|
||||
model_container = ModelContainer(model_path, False, **data.model_dump())
|
||||
load_status = model_container.load_gen(load_progress)
|
||||
for (module, modules) in load_status:
|
||||
if module == 0:
|
||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
||||
elif module == modules:
|
||||
loading_bar.next()
|
||||
loading_bar.finish()
|
||||
else:
|
||||
loading_bar.next()
|
||||
|
||||
yield ModelLoadResponse(
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="processing"
|
||||
).model_dump_json()
|
||||
|
||||
yield ModelLoadResponse(
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="finished"
|
||||
).model_dump_json()
|
||||
|
||||
return EventSourceResponse(generator())
|
||||
|
||||
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
|
||||
async def unload_model():
|
||||
global model_container
|
||||
|
||||
if model_container is None:
|
||||
raise HTTPException(400, "No models are loaded.")
|
||||
|
||||
model_container.unload()
|
||||
model_container = None
|
||||
|
||||
@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)])
|
||||
async def generate_completion(request: Request, data: CompletionRequest):
|
||||
if data.stream:
|
||||
async def generator():
|
||||
@@ -44,31 +100,32 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
||||
|
||||
return response.model_dump_json()
|
||||
|
||||
|
||||
# Wrapper callback for load progress
|
||||
def load_progress(module, modules):
|
||||
yield module, modules
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize auth keys
|
||||
load_auth_keys()
|
||||
|
||||
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
||||
with open('config.yml', 'r') as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
|
||||
# If an initial model name is specified, create a container and load the model
|
||||
if config["model_name"]:
|
||||
model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}"
|
||||
|
||||
model_container = ModelContainer(model_path, False, **config)
|
||||
model_config = config["model"]
|
||||
if model_config["model_name"]:
|
||||
model_path = pathlib.Path(model_config["model_dir"] or "models")
|
||||
model_path = model_path / model_config["model_name"]
|
||||
|
||||
model_container = ModelContainer(model_path, False, **model_config)
|
||||
load_status = model_container.load_gen(load_progress)
|
||||
for (module, modules) in load_status:
|
||||
if module == 0:
|
||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
||||
elif module == modules:
|
||||
loading_bar.next()
|
||||
loading_bar.finish()
|
||||
else:
|
||||
loading_bar.next()
|
||||
|
||||
if module == modules:
|
||||
loading_bar.finish()
|
||||
|
||||
|
||||
print("Model successfully loaded.")
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
|
||||
network_config = config["network"]
|
||||
uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")
|
||||
|
||||
Reference in New Issue
Block a user