OAI: Add API-based model loading/unloading and auth routes

Models can be loaded and unloaded via the API. Also add authentication
to use the API and for administrator tasks.

Both types of authorization use different keys.

Also fix the unload function to properly free all used vram.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-14 01:17:19 -05:00
parent 47343e2f1a
commit b625bface9
11 changed files with 195 additions and 55 deletions

View File

@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from time import time
from typing import List
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)

View File

@@ -2,7 +2,7 @@ from uuid import uuid4
from time import time from time import time
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from OAI.models.common import LogProbs, UsageStats from OAI.types.common import LogProbs, UsageStats
class CompletionRespChoice(BaseModel): class CompletionRespChoice(BaseModel):
finish_reason: str finish_reason: str

27
OAI/types/models.py Normal file
View File

@@ -0,0 +1,27 @@
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional
class ModelCard(BaseModel):
id: str = "test"
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class ModelLoadRequest(BaseModel):
name: str
max_seq_len: Optional[int] = 4096
gpu_split: Optional[str] = "auto"
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
no_flash_attention: Optional[bool] = False
low_mem: Optional[bool] = False
class ModelLoadResponse(BaseModel):
module: int
modules: int
status: str

View File

@@ -1,7 +1,7 @@
import pathlib import pathlib
from OAI.models.completions import CompletionResponse, CompletionRespChoice from OAI.types.completions import CompletionResponse, CompletionRespChoice
from OAI.models.common import UsageStats from OAI.types.common import UsageStats
from OAI.models.models import ModelList, ModelCard from OAI.types.models import ModelList, ModelCard
from typing import Optional from typing import Optional
def create_completion_response(text: str, index: int, model_name: Optional[str]): def create_completion_response(text: str, index: int, model_name: Optional[str]):

3
api_tokens.yml Normal file
View File

@@ -0,0 +1,3 @@
!!python/object:auth.AuthKeys
admin_key: 5b9e30a4197557dcd6cf48445ee174dc
api_key: 2261702e8a220c6c4671a264cd1236ce

54
auth.py Normal file
View File

@@ -0,0 +1,54 @@
import secrets
import yaml
from fastapi import Header, HTTPException
from typing import Optional
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
class AuthKeys:
api_key: str
admin_key: str
def __init__(self, api_key: str, admin_key: str):
self.api_key = api_key
self.admin_key = admin_key
auth_keys: Optional[AuthKeys] = None
def load_auth_keys():
global auth_keys
try:
with open("api_tokens.yml", "r") as auth_file:
auth_keys = yaml.safe_load(auth_file)
except:
new_auth_keys = AuthKeys(
api_key = secrets.token_hex(16),
admin_key = secrets.token_hex(16)
)
auth_keys = new_auth_keys
with open("api_tokens.yml", "w") as auth_file:
yaml.dump(auth_keys, auth_file)
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
if x_api_key and x_api_key == auth_keys.api_key:
return x_api_key
elif authorization:
split_key = authorization.split(" ")
if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.api_key:
return authorization
else:
raise HTTPException(401, "Invalid API key")
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
if x_admin_key and x_admin_key == auth_keys.admin_key:
return x_admin_key
elif authorization:
split_key = authorization.split(" ")
if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.admin_key:
return authorization
else:
raise HTTPException(401, "Invalid admin key")

View File

@@ -1,8 +1,14 @@
model_dir: "D:/models" # Network options
model_name: "this_is_a_exl2_model" network:
max_seq_len: 4096 host: "0.0.0.0"
gpu_split: "auto" port: 8012
rope_scale: 1.0 # Only used if you want to initially load a model
rope_alpha: 1.0 model:
no_flash_attention: False model_dir: "D:/models"
low_mem: False model_name: "airoboros-mistral2.2-7b-exl2"
max_seq_len: 4096
gpu_split: "auto"
rope_scale: 1.0
rope_alpha: 1.0
no_flash_attention: False
low_mem: False

103
main.py
View File

@@ -1,30 +1,86 @@
import uvicorn import uvicorn
import yaml import yaml
from fastapi import FastAPI, Request import pathlib
from auth import check_admin_key, check_api_key, load_auth_keys
from fastapi import FastAPI, Request, HTTPException, Depends
from model import ModelContainer from model import ModelContainer
from progress.bar import IncrementalBar from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
from OAI.models.completions import CompletionRequest, CompletionResponse from OAI.types.completions import CompletionRequest, CompletionResponse
from OAI.models.models import ModelCard, ModelList from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse
from OAI.utils import create_completion_response, get_model_list from OAI.utils import create_completion_response, get_model_list
from typing import Optional
from utils import load_progress
app = FastAPI() app = FastAPI()
# Initialize a model container. This can be undefined at any period of time # Globally scoped variables. Undefined until initalized in main
model_container: ModelContainer = None model_container: Optional[ModelContainer] = None
config: Optional[dict] = None
@app.get("/v1/models") @app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list") @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(): async def list_models():
models = get_model_list(model_container.get_model_path()) model_config = config["model"]
models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
return models.model_dump_json() return models.model_dump_json()
@app.get("/v1/model") @app.get("/v1/model", dependencies=[Depends(check_api_key)])
async def get_current_model(): async def get_current_model():
return ModelCard(id = model_container.get_model_path().name) if model_container is None or model_container.model is None:
return HTTPException(400, "No models are loaded.")
@app.post("/v1/completions", response_class=CompletionResponse) model_card = ModelCard(id=model_container.get_model_path().name)
return model_card.model_dump_json()
@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
def generator():
global model_container
model_config = config["model"]
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_path = model_path / data.name
model_container = ModelContainer(model_path, False, **data.model_dump())
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
else:
loading_bar.next()
yield ModelLoadResponse(
module=module,
modules=modules,
status="processing"
).model_dump_json()
yield ModelLoadResponse(
module=module,
modules=modules,
status="finished"
).model_dump_json()
return EventSourceResponse(generator())
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
async def unload_model():
global model_container
if model_container is None:
raise HTTPException(400, "No models are loaded.")
model_container.unload()
model_container = None
@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest): async def generate_completion(request: Request, data: CompletionRequest):
if data.stream: if data.stream:
async def generator(): async def generator():
@@ -44,31 +100,32 @@ async def generate_completion(request: Request, data: CompletionRequest):
return response.model_dump_json() return response.model_dump_json()
# Wrapper callback for load progress
def load_progress(module, modules):
yield module, modules
if __name__ == "__main__": if __name__ == "__main__":
# Initialize auth keys
load_auth_keys()
# Load from YAML config. Possibly add a config -> kwargs conversion function # Load from YAML config. Possibly add a config -> kwargs conversion function
with open('config.yml', 'r') as config_file: with open('config.yml', 'r') as config_file:
config = yaml.safe_load(config_file) config = yaml.safe_load(config_file)
# If an initial model name is specified, create a container and load the model # If an initial model name is specified, create a container and load the model
if config["model_name"]: model_config = config["model"]
model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}" if model_config["model_name"]:
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_path = model_path / model_config["model_name"]
model_container = ModelContainer(model_path, False, **config) model_container = ModelContainer(model_path, False, **model_config)
load_status = model_container.load_gen(load_progress) load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status: for (module, modules) in load_status:
if module == 0: if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
else: else:
loading_bar.next() loading_bar.next()
if module == modules:
loading_bar.finish()
print("Model successfully loaded.") print("Model successfully loaded.")
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug") network_config = config["network"]
uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")

View File

@@ -32,7 +32,7 @@ class ModelContainer:
gpu_split_auto: bool = True gpu_split_auto: bool = True
gpu_split: list or None = None gpu_split: list or None = None
def __init__(self, model_directory: str, quiet = False, **kwargs): def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
""" """
Create model container Create model container
@@ -62,11 +62,11 @@ class ModelContainer:
self.quiet = quiet self.quiet = quiet
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8" self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
self.gpu_split = kwargs.get("gpu_split", None) self.gpu_split = kwargs.get("gpu_split", None)
self.gpu_split_auto = self.gpu_split == "auto"
self.config = ExLlamaV2Config() self.config = ExLlamaV2Config()
self.config.model_dir = model_directory self.config.model_dir = str(model_directory.resolve())
self.config.prepare() self.config.prepare()
if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"] if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
@@ -85,7 +85,7 @@ class ModelContainer:
if self.draft_enabled: if self.draft_enabled:
self.draft_config = ExLlamaV2Config() self.draft_config = ExLlamaV2Config()
self.draft_config.model_dir = kwargs["draft_model_directory"] self.draft_config.model_dir = kwargs["draft_model_dir"]
self.draft_config.prepare() self.draft_config.prepare()
self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.max_seq_len = self.config.max_seq_len
@@ -103,7 +103,7 @@ class ModelContainer:
def get_model_path(self): def get_model_path(self):
model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir) model_path = pathlib.Path(self.config.model_dir)
return model_path return model_path
@@ -185,9 +185,12 @@ class ModelContainer:
if self.model: self.model.unload() if self.model: self.model.unload()
self.model = None self.model = None
if self.draft_model: self.draft_model.unload()
self.draft_model = None
self.config = None self.config = None
self.cache = None self.cache = None
self.tokenizer = None self.tokenizer = None
self.generator = None
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()

3
utils.py Normal file
View File

@@ -0,0 +1,3 @@
# Wrapper callback for load progress
def load_progress(module, modules):
yield module, modules