API: Migrate universal routes to core

Place OAI-specific routes in the appropriate folder. This is in
preparation for adding new API servers that can be optionally enabled.
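
For context, a minimal sketch of how the split could be wired once the universal routes live in a core module. The module paths (endpoints.core.router, endpoints.OAI.router), the setup_app helper, and the enabled_servers value are assumptions for illustration, not part of this commit.

# Illustrative only: assumed module paths and setup helper, not from this commit.
from fastapi import FastAPI

from endpoints.core.router import router as core_router  # universal routes (assumed path)
from endpoints.OAI.router import router as oai_router    # OAI-specific routes (assumed path)

def setup_app(enabled_servers: list) -> FastAPI:
    """Mount the core router unconditionally and optional API servers on demand."""
    app = FastAPI()
    app.include_router(core_router)
    if "OAI" in enabled_servers:
        app.include_router(oai_router)
    return app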

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-23 14:08:48 -04:00
parent 64c2cc85c9
commit 9ad69e8ab6
13 changed files with 442 additions and 427 deletions

View File

@@ -1,45 +1,18 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common import config, model
from common.auth import check_api_key
from common.model import check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
LoraList,
LoraLoadRequest,
LoraLoadResponse,
)
from endpoints.OAI.types.model import (
ModelCard,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
TokenDecodeRequest,
TokenDecodeResponse,
)
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
@@ -49,13 +22,6 @@ from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.model import (
get_current_model,
get_current_model_list,
get_model_list,
stream_model_load,
)
from endpoints.OAI.utils.lora import get_active_loras, get_lora_list
router = APIRouter()
@@ -159,392 +125,3 @@ async def chat_completion_request(
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(request: Request) -> ModelList:
"""
Lists all models in the model directory.
Requires an admin key to see all models.
"""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
# Currently loaded model endpoint
@router.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def current_model() -> ModelCard:
"""Returns the currently loaded model."""
return get_current_model()
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(request: Request) -> ModelList:
"""
Lists all draft models in the model directory.
Requires an admin key to see all draft models.
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
else:
models = await get_current_model_list(is_draft=True)
return models
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""
# Verify request parameters
if not data.name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
error_message = handle_request_error(
"Could not find the model path for load. Check model name or config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
return EventSourceResponse(
stream_model_load(data, model_path, draft_model_path), ping=maxsize
)
# Unload model endpoint
@router.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
"""Unloads the currently loaded model."""
await model.unload_model(skip_wait=True)
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
"""Downloads a model from HuggingFace."""
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras(request: Request) -> LoraList:
"""
Lists all LoRAs in the lora directory.
Requires an admin key to see all LoRAs.
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
return loras
# Currently loaded loras endpoint
@router.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def active_loras() -> LoraList:
"""Returns the currently loaded loras."""
return get_active_loras()
# Load lora endpoint
@router.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
"""Loads a LoRA into the model container."""
if not data.loras:
error_message = handle_request_error(
"List of loras to load is not found.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
load_result = await model.load_loras(
lora_dir, **data.model_dump(), skip_wait=data.skip_queue
)
return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
failure=unwrap(load_result.get("failure"), []),
)
# Unload lora endpoint
@router.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
"""Unloads the currently loaded loras."""
await model.unload_loras()
# Encode tokens endpoint
@router.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
if isinstance(data.text, str):
text = data.text
else:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
return response
# Decode tokens endpoint
@router.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
"""Decodes tokens into a string."""
message = model.container.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
return response
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def key_permission(request: Request) -> AuthPermissionResponse:
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-admin-key
- X-api-key
- Authorization
"""
try:
permission = get_key_permission(request)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates(request: Request) -> TemplateList:
"""
Get a list of all templates.
Requires an admin key to see all templates.
"""
template_strings = []
if get_key_permission(request) == "admin":
templates = get_all_templates()
template_strings = [template.stem for template in templates]
else:
if model.container and model.container.prompt_template:
template_strings.append(model.container.prompt_template.name)
return TemplateList(data=template_strings)
@router.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""
if not data.name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
try:
model.container.prompt_template = PromptTemplate.from_file(data.name)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
@router.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
"""Unloads the currently selected template"""
model.container.prompt_template = None
# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
"""
List all currently applied sampler overrides.
Requires an admin key to see all override presets.
"""
if get_key_permission(request) == "admin":
presets = sampling.get_all_presets()
else:
presets = []
return SamplerOverrideListResponse(
presets=presets, **sampling.overrides_container.model_dump()
)
@router.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""
if data.preset:
try:
sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "
+ "Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
elif data.overrides:
sampling.overrides_from_dict(data.overrides)
else:
error_message = handle_request_error(
"A sampler override preset or dictionary wasn't provided.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
@router.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
sampling.overrides_from_dict({})
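
The routes above keep their paths after the move, so clients are unaffected. As a small client-side sketch, the model list endpoint can be queried as below; the host and port and the placeholder key are assumptions, while the X-api-key header name comes from the auth priority list documented in key_permission.

# Hypothetical client call; adjust host, port, and key for your deployment.
import requests

response = requests.get(
    "http://localhost:5000/v1/models",           # assumed host/port
    headers={"X-api-key": "your-api-key-here"},  # header name from the auth priority list
)
response.raise_for_status()
for card in response.json()["data"]:
    print(card["id"])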

View File

@@ -1,7 +0,0 @@
"""Types for auth requests."""
from pydantic import BaseModel
class AuthPermissionResponse(BaseModel):
permission: str

View File

@@ -1,25 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional
def _generate_include_list():
return ["*"]
class DownloadRequest(BaseModel):
"""Parameters for a HuggingFace repo download."""
repo_id: str
repo_type: str = "model"
folder_name: Optional[str] = None
revision: Optional[str] = None
token: Optional[str] = None
include: List[str] = Field(default_factory=_generate_include_list)
exclude: List[str] = Field(default_factory=list)
chunk_limit: Optional[int] = None
class DownloadResponse(BaseModel):
"""Response for a download request."""
download_path: str

View File

@@ -1,43 +0,0 @@
"""Lora types"""
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List
class LoraCard(BaseModel):
"""Represents a single Lora card."""
id: str = "test"
object: str = "lora"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
scaling: Optional[float] = None
class LoraList(BaseModel):
"""Represents a list of Lora cards."""
object: str = "list"
data: List[LoraCard] = Field(default_factory=list)
class LoraLoadInfo(BaseModel):
"""Represents a single Lora load info."""
name: str
scaling: Optional[float] = 1.0
class LoraLoadRequest(BaseModel):
"""Represents a Lora load request."""
loras: List[LoraLoadInfo]
skip_queue: bool = False
class LoraLoadResponse(BaseModel):
"""Represents a Lora load response."""
success: List[str] = Field(default_factory=list)
failure: List[str] = Field(default_factory=list)

View File

@@ -1,149 +0,0 @@
"""Contains model card types."""
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from common.gen_logging import GenLogPreferences
from common.model import get_config_default
class ModelCardParameters(BaseModel):
"""Represents model card parameters."""
# Safe to do this since it's guaranteed to fetch a max seq len
# from model_container
max_seq_len: Optional[int] = None
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
cache_size: Optional[int] = None
cache_mode: Optional[str] = "FP16"
chunk_size: Optional[int] = 2048
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
# Draft is another model, so include it in the card params
draft: Optional["ModelCard"] = None
class ModelCard(BaseModel):
"""Represents a single model card."""
id: str = "test"
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[GenLogPreferences] = None
parameters: Optional[ModelCardParameters] = None
class ModelList(BaseModel):
"""Represents a list of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
# Required
draft_model_name: str
# Config arguments
draft_rope_scale: Optional[float] = Field(
default_factory=lambda: get_config_default(
"draft_rope_scale", 1.0, is_draft=True
)
)
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default_factory=lambda: get_config_default(
"draft_rope_alpha", None, is_draft=True
),
examples=[1.0],
)
draft_cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default(
"draft_cache_mode", "FP16", is_draft=True
)
)
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""
# Required
name: str
# Config arguments
# Max seq len is fetched from config.json of the model by default
max_seq_len: Optional[int] = Field(
description="Leave this blank to use the model's base sequence length",
default_factory=lambda: get_config_default("max_seq_len"),
examples=[4096],
)
override_base_seq_len: Optional[int] = Field(
description=(
"Overrides the model's base sequence length. " "Leave blank if unsure"
),
default_factory=lambda: get_config_default("override_base_seq_len"),
examples=[4096],
)
cache_size: Optional[int] = Field(
description=("Number in tokens, must be greater than or equal to max_seq_len"),
default_factory=lambda: get_config_default("cache_size"),
examples=[4096],
)
gpu_split_auto: Optional[bool] = Field(
default_factory=lambda: get_config_default("gpu_split_auto", True)
)
autosplit_reserve: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("autosplit_reserve", [96])
)
gpu_split: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("gpu_split", []),
examples=[[24.0, 20.0]],
)
rope_scale: Optional[float] = Field(
description="Automatically pulled from the model's config if not present",
default_factory=lambda: get_config_default("rope_scale"),
examples=[1.0],
)
rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default_factory=lambda: get_config_default("rope_alpha"),
examples=[1.0],
)
cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default("cache_mode", "FP16")
)
chunk_size: Optional[int] = Field(
default_factory=lambda: get_config_default("chunk_size", 2048)
)
prompt_template: Optional[str] = Field(
default_factory=lambda: get_config_default("prompt_template")
)
num_experts_per_token: Optional[int] = Field(
default_factory=lambda: get_config_default("num_experts_per_token")
)
fasttensors: Optional[bool] = Field(
default_factory=lambda: get_config_default("fasttensors", False)
)
# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
skip_queue: Optional[bool] = False
class ModelLoadResponse(BaseModel):
"""Represents a model load response."""
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])
model_type: str = "model"
module: int
modules: int
status: str
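
To tie these types back to the /v1/model/load route in the first hunk, a sketch of a load payload built from the fields ModelLoadRequest defines; the model names and values here are placeholders, and only name is required.

# Example payload only; field names mirror ModelLoadRequest, values are placeholders.
load_payload = {
    "name": "MyModel-exl2",       # required
    "max_seq_len": 4096,          # optional, pulled from the model's config.json if omitted
    "cache_mode": "FP16",
    "gpu_split_auto": True,
    "draft": {                    # optional DraftModelLoadRequest
        "draft_model_name": "MyDraftModel-exl2",
        "draft_cache_mode": "FP16",
    },
    "skip_queue": False,
}
# POSTed to /v1/model/load with an admin key; the response is an SSE stream of
# ModelLoadResponse objects carrying module/modules progress and a status field.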

View File

@@ -1,34 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional
from common.sampling import SamplerOverridesContainer
class SamplerOverrideListResponse(SamplerOverridesContainer):
"""Sampler override list response"""
presets: Optional[List[str]]
class SamplerOverrideSwitchRequest(BaseModel):
"""Sampler override switch request"""
preset: Optional[str] = Field(
default=None, description="Pass a sampler override preset name"
)
overrides: Optional[dict] = Field(
default=None,
description=(
"Sampling override parent takes in individual keys and overrides. "
+ "Ignored if preset is provided."
),
examples=[
{
"top_p": {
"override": 1.5,
"force": False,
}
}
],
)

View File

@@ -1,15 +0,0 @@
from pydantic import BaseModel, Field
from typing import List
class TemplateList(BaseModel):
"""Represents a list of templates."""
object: str = "list"
data: List[str] = Field(default_factory=list)
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""
name: str

View File

@@ -1,51 +0,0 @@
"""Tokenization types"""
from pydantic import BaseModel
from typing import Dict, List, Union
class CommonTokenRequest(BaseModel):
"""Represents a common tokenization request."""
add_bos_token: bool = True
encode_special_tokens: bool = True
decode_special_tokens: bool = True
def get_params(self):
"""Get the parameters for tokenization."""
return {
"add_bos_token": self.add_bos_token,
"encode_special_tokens": self.encode_special_tokens,
"decode_special_tokens": self.decode_special_tokens,
}
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: Union[str, List[Dict[str, str]]]
class TokenEncodeResponse(BaseModel):
"""Represents a tokenization response."""
tokens: List[int]
length: int
class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""
tokens: List[int]
class TokenDecodeResponse(BaseModel):
"""Represents a detokenization response."""
text: str
class TokenCountResponse(BaseModel):
"""Represents a token count response."""
length: int
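
Since TokenEncodeRequest accepts either a raw string or chat-completion style messages, a sketch of both request bodies for /v1/token/encode may help; the message content is a placeholder and the defaults follow CommonTokenRequest above.

# Example request bodies only; values are placeholders.
encode_string_payload = {
    "text": "Hello, world!",
    "add_bos_token": True,             # defaults from CommonTokenRequest
    "encode_special_tokens": True,
}
encode_messages_payload = {
    "text": [{"role": "user", "content": "Hello!"}],  # rendered through the prompt template
}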

View File

@@ -1,30 +0,0 @@
import pathlib
from common import model
from endpoints.OAI.types.lora import LoraCard, LoraList
def get_lora_list(lora_path: pathlib.Path):
"""Get the list of Lora cards from the provided path."""
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card)
return lora_list
def get_active_loras():
if model.container:
active_loras = [
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
for lora in model.container.get_loras()
]
else:
active_loras = []
return LoraList(data=active_loras)

View File

@@ -1,123 +0,0 @@
import pathlib
from asyncio import CancelledError
from typing import Optional
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.model import (
ModelCard,
ModelCardParameters,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
)
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
"""Get the list of models from the provided path."""
# Convert the provided draft model path to a pathlib path for
# equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()
model_card_list = ModelList()
for path in model_path.iterdir():
# Don't include the draft models path
if path.is_dir() and path != draft_model_path:
model_card = ModelCard(id=path.name)
model_card_list.data.append(model_card) # pylint: disable=no-member
return model_card_list
async def get_current_model_list(is_draft: bool = False):
"""Gets the current model in list format and with path only."""
current_models = []
# Make sure the model container exists
if model.container:
model_path = model.container.get_model_path(is_draft)
if model_path:
current_models.append(ModelCard(id=model_path.name))
return ModelList(data=current_models)
def get_current_model():
"""Gets the current model with all parameters."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
async def stream_model_load(
data: ModelLoadRequest,
model_path: pathlib.Path,
draft_model_path: str,
):
"""Request generation wrapper for the loading process."""
# Set the draft model path if it exists
load_data = data.model_dump()
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data
)
try:
async for module, modules, model_type in load_status:
if module != 0:
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="processing",
)
yield response.model_dump_json()
if module == modules:
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="finished",
)
yield response.model_dump_json()
except CancelledError:
# Get out if the request gets disconnected
handle_request_disconnect(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
except Exception as exc:
yield get_generator_error(str(exc))