Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-26 01:08:52 +00:00)
API: Migrate universal routes to core
Place OAI-specific routes in the appropriate folder. This is in preparation for adding new API servers that can be optionally enabled.
Signed-off-by: kingbri <bdashore3@proton.me>
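To make the intent of the restructure concrete, here is a minimal sketch of how a core FastAPI app could conditionally include per-API routers once universal routes live in their own package. This is illustrative only and not part of the commit: the `endpoints.core.router` module path and the `enabled_servers` parameter are assumptions, while `endpoints.OAI.router` matches the file being trimmed below.

```python
# Hypothetical wiring, not from this commit: a core app that always mounts the
# shared (universal) routes and optionally mounts API-specific servers.
from fastapi import FastAPI


def build_app(enabled_servers: list[str]) -> FastAPI:
    """Assemble the app from a core router plus optional API routers."""
    app = FastAPI()

    # Universal routes (models, loras, templates, auth, tokens) live in core.
    # The module path "endpoints.core.router" is an assumption.
    from endpoints.core.router import router as core_router

    app.include_router(core_router)

    # OAI-specific routes (completions, chat completions) stay in their folder
    # and are only mounted when that API server is enabled.
    if "OAI" in enabled_servers:
        from endpoints.OAI.router import router as oai_router

        app.include_router(oai_router)

    return app


app = build_app(enabled_servers=["OAI"])
```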
@@ -1,45 +1,18 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize

from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common import config, model
from common.auth import check_api_key
from common.model import check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
    LoraList,
    LoraLoadRequest,
    LoraLoadResponse,
)
from endpoints.OAI.types.model import (
    ModelCard,
    ModelList,
    ModelLoadRequest,
    ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
    SamplerOverrideListResponse,
    SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
    TokenEncodeRequest,
    TokenEncodeResponse,
    TokenDecodeRequest,
    TokenDecodeResponse,
)
from endpoints.OAI.utils.chat_completion import (
    format_prompt_with_template,
    generate_chat_completion,
@@ -49,13 +22,6 @@ from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.model import (
    get_current_model,
    get_current_model_list,
    get_model_list,
    stream_model_load,
)
from endpoints.OAI.utils.lora import get_active_loras, get_lora_list


router = APIRouter()
@@ -159,392 +125,3 @@ async def chat_completion_request(
        disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
    )
    return response


# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(request: Request) -> ModelList:
    """
    Lists all models in the model directory.

    Requires an admin key to see all models.
    """

    model_config = config.model_config()
    model_dir = unwrap(model_config.get("model_dir"), "models")
    model_path = pathlib.Path(model_dir)

    draft_model_dir = config.draft_model_config().get("draft_model_dir")

    if get_key_permission(request) == "admin":
        models = get_model_list(model_path.resolve(), draft_model_dir)
    else:
        models = await get_current_model_list()

    if unwrap(model_config.get("use_dummy_models"), False):
        models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

    return models


# Currently loaded model endpoint
@router.get(
    "/v1/model",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def current_model() -> ModelCard:
    """Returns the currently loaded model."""

    return get_current_model()


@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(request: Request) -> ModelList:
    """
    Lists all draft models in the model directory.

    Requires an admin key to see all draft models.
    """

    if get_key_permission(request) == "admin":
        draft_model_dir = unwrap(
            config.draft_model_config().get("draft_model_dir"), "models"
        )
        draft_model_path = pathlib.Path(draft_model_dir)

        models = get_model_list(draft_model_path.resolve())
    else:
        models = await get_current_model_list(is_draft=True)

    return models


# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
    """Loads a model into the model container. This returns an SSE stream."""

    # Verify request parameters
    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
    model_path = model_path / data.name

    draft_model_path = None
    if data.draft:
        if not data.draft.draft_model_name:
            error_message = handle_request_error(
                "Could not find the draft model name for model load.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

        draft_model_path = unwrap(
            config.draft_model_config().get("draft_model_dir"), "models"
        )

    if not model_path.exists():
        error_message = handle_request_error(
            "Could not find the model path for load. Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    return EventSourceResponse(
        stream_model_load(data, model_path, draft_model_path), ping=maxsize
    )


# Unload model endpoint
@router.post(
    "/v1/model/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
    """Unloads the currently loaded model."""
    await model.unload_model(skip_wait=True)


@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
    """Downloads a model from HuggingFace."""

    try:
        download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))

        # For now, the downloader and request data are 1:1
        download_path = await run_with_request_disconnect(
            request,
            download_task,
            "Download request cancelled by user. Files have been cleaned up.",
        )

        return DownloadResponse(download_path=str(download_path))
    except Exception as exc:
        error_message = handle_request_error(str(exc)).error.message

        raise HTTPException(400, error_message) from exc


# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras(request: Request) -> LoraList:
    """
    Lists all LoRAs in the lora directory.

    Requires an admin key to see all LoRAs.
    """

    if get_key_permission(request) == "admin":
        lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
        loras = get_lora_list(lora_path.resolve())
    else:
        loras = get_active_loras()

    return loras


# Currently loaded loras endpoint
@router.get(
    "/v1/lora",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def active_loras() -> LoraList:
    """Returns the currently loaded loras."""

    return get_active_loras()


# Load lora endpoint
@router.post(
    "/v1/lora/load",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
    """Loads a LoRA into the model container."""

    if not data.loras:
        error_message = handle_request_error(
            "List of loras to load is not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    if not lora_dir.exists():
        error_message = handle_request_error(
            "A parent lora directory does not exist for load. Check your config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    load_result = await model.load_loras(
        lora_dir, **data.model_dump(), skip_wait=data.skip_queue
    )

    return LoraLoadResponse(
        success=unwrap(load_result.get("success"), []),
        failure=unwrap(load_result.get("failure"), []),
    )


# Unload lora endpoint
@router.post(
    "/v1/lora/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
    """Unloads the currently loaded loras."""

    await model.unload_loras()


# Encode tokens endpoint
@router.post(
    "/v1/token/encode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
    """Encodes a string or chat completion messages into tokens."""

    if isinstance(data.text, str):
        text = data.text
    else:
        special_tokens_dict = model.container.get_special_tokens(
            unwrap(data.add_bos_token, True)
        )

        template_vars = {
            "messages": data.text,
            "add_generation_prompt": False,
            **special_tokens_dict,
        }

        text, _ = model.container.prompt_template.render(template_vars)

    raw_tokens = model.container.encode_tokens(text, **data.get_params())
    tokens = unwrap(raw_tokens, [])
    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

    return response


# Decode tokens endpoint
@router.post(
    "/v1/token/decode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
    """Decodes tokens into a string."""

    message = model.container.decode_tokens(data.tokens, **data.get_params())
    response = TokenDecodeResponse(text=unwrap(message, ""))

    return response


@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def key_permission(request: Request) -> AuthPermissionResponse:
    """
    Gets the access level/permission of a provided key in headers.

    Priority:
    - X-admin-key
    - X-api-key
    - Authorization
    """

    try:
        permission = get_key_permission(request)
        return AuthPermissionResponse(permission=permission)
    except ValueError as exc:
        error_message = handle_request_error(str(exc)).error.message

        raise HTTPException(400, error_message) from exc


@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates(request: Request) -> TemplateList:
    """
    Get a list of all templates.

    Requires an admin key to see all templates.
    """

    template_strings = []
    if get_key_permission(request) == "admin":
        templates = get_all_templates()
        template_strings = [template.stem for template in templates]
    else:
        if model.container and model.container.prompt_template:
            template_strings.append(model.container.prompt_template.name)

    return TemplateList(data=template_strings)


@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Switch the currently loaded template."""

    if not data.name:
        error_message = handle_request_error(
            "New template name not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        model.container.prompt_template = PromptTemplate.from_file(data.name)
    except FileNotFoundError as e:
        error_message = handle_request_error(
            f"The template name {data.name} doesn't exist. Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from e


@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Unloads the currently selected template"""

    model.container.prompt_template = None


# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
    """
    List all currently applied sampler overrides.

    Requires an admin key to see all override presets.
    """

    if get_key_permission(request) == "admin":
        presets = sampling.get_all_presets()
    else:
        presets = []

    return SamplerOverrideListResponse(
        presets=presets, **sampling.overrides_container.model_dump()
    )


@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Switch the currently loaded override preset"""

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as e:
            error_message = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message) from e
    elif data.overrides:
        sampling.overrides_from_dict(data.overrides)
    else:
        error_message = handle_request_error(
            "A sampler override preset or dictionary wasn't provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)


@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Unloads the currently selected override preset"""

    sampling.overrides_from_dict({})
@@ -1,7 +0,0 @@
"""Types for auth requests."""

from pydantic import BaseModel


class AuthPermissionResponse(BaseModel):
    permission: str
@@ -1,25 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional


def _generate_include_list():
    return ["*"]


class DownloadRequest(BaseModel):
    """Parameters for a HuggingFace repo download."""

    repo_id: str
    repo_type: str = "model"
    folder_name: Optional[str] = None
    revision: Optional[str] = None
    token: Optional[str] = None
    include: List[str] = Field(default_factory=_generate_include_list)
    exclude: List[str] = Field(default_factory=list)
    chunk_limit: Optional[int] = None


class DownloadResponse(BaseModel):
    """Response for a download request."""

    download_path: str
@@ -1,43 +0,0 @@
"""Lora types"""

from pydantic import BaseModel, Field
from time import time
from typing import Optional, List


class LoraCard(BaseModel):
    """Represents a single Lora card."""

    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None


class LoraList(BaseModel):
    """Represents a list of Lora cards."""

    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)


class LoraLoadInfo(BaseModel):
    """Represents a single Lora load info."""

    name: str
    scaling: Optional[float] = 1.0


class LoraLoadRequest(BaseModel):
    """Represents a Lora load request."""

    loras: List[LoraLoadInfo]
    skip_queue: bool = False


class LoraLoadResponse(BaseModel):
    """Represents a Lora load response."""

    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)
@@ -1,149 +0,0 @@
"""Contains model card types."""

from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional

from common.gen_logging import GenLogPreferences
from common.model import get_config_default


class ModelCardParameters(BaseModel):
    """Represents model card parameters."""

    # Safe to do this since it's guaranteed to fetch a max seq len
    # from model_container
    max_seq_len: Optional[int] = None
    rope_scale: Optional[float] = 1.0
    rope_alpha: Optional[float] = 1.0
    cache_size: Optional[int] = None
    cache_mode: Optional[str] = "FP16"
    chunk_size: Optional[int] = 2048
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None

    # Draft is another model, so include it in the card params
    draft: Optional["ModelCard"] = None


class ModelCard(BaseModel):
    """Represents a single model card."""

    id: str = "test"
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    logging: Optional[GenLogPreferences] = None
    parameters: Optional[ModelCardParameters] = None


class ModelList(BaseModel):
    """Represents a list of model cards."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class DraftModelLoadRequest(BaseModel):
    """Represents a draft model load request."""

    # Required
    draft_model_name: str

    # Config arguments
    draft_rope_scale: Optional[float] = Field(
        default_factory=lambda: get_config_default(
            "draft_rope_scale", 1.0, is_draft=True
        )
    )
    draft_rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
        default_factory=lambda: get_config_default(
            "draft_rope_alpha", None, is_draft=True
        ),
        examples=[1.0],
    )
    draft_cache_mode: Optional[str] = Field(
        default_factory=lambda: get_config_default(
            "draft_cache_mode", "FP16", is_draft=True
        )
    )


class ModelLoadRequest(BaseModel):
    """Represents a model load request."""

    # Required
    name: str

    # Config arguments

    # Max seq len is fetched from config.json of the model by default
    max_seq_len: Optional[int] = Field(
        description="Leave this blank to use the model's base sequence length",
        default_factory=lambda: get_config_default("max_seq_len"),
        examples=[4096],
    )
    override_base_seq_len: Optional[int] = Field(
        description=(
            "Overrides the model's base sequence length. " "Leave blank if unsure"
        ),
        default_factory=lambda: get_config_default("override_base_seq_len"),
        examples=[4096],
    )
    cache_size: Optional[int] = Field(
        description=("Number in tokens, must be greater than or equal to max_seq_len"),
        default_factory=lambda: get_config_default("cache_size"),
        examples=[4096],
    )
    gpu_split_auto: Optional[bool] = Field(
        default_factory=lambda: get_config_default("gpu_split_auto", True)
    )
    autosplit_reserve: Optional[List[float]] = Field(
        default_factory=lambda: get_config_default("autosplit_reserve", [96])
    )
    gpu_split: Optional[List[float]] = Field(
        default_factory=lambda: get_config_default("gpu_split", []),
        examples=[[24.0, 20.0]],
    )
    rope_scale: Optional[float] = Field(
        description="Automatically pulled from the model's config if not present",
        default_factory=lambda: get_config_default("rope_scale"),
        examples=[1.0],
    )
    rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
        default_factory=lambda: get_config_default("rope_alpha"),
        examples=[1.0],
    )
    cache_mode: Optional[str] = Field(
        default_factory=lambda: get_config_default("cache_mode", "FP16")
    )
    chunk_size: Optional[int] = Field(
        default_factory=lambda: get_config_default("chunk_size", 2048)
    )
    prompt_template: Optional[str] = Field(
        default_factory=lambda: get_config_default("prompt_template")
    )
    num_experts_per_token: Optional[int] = Field(
        default_factory=lambda: get_config_default("num_experts_per_token")
    )
    fasttensors: Optional[bool] = Field(
        default_factory=lambda: get_config_default("fasttensors", False)
    )

    # Non-config arguments
    draft: Optional[DraftModelLoadRequest] = None
    skip_queue: Optional[bool] = False


class ModelLoadResponse(BaseModel):
    """Represents a model load response."""

    # Avoids pydantic namespace warning
    model_config = ConfigDict(protected_namespaces=[])

    model_type: str = "model"
    module: int
    modules: int
    status: str
@@ -1,34 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional

from common.sampling import SamplerOverridesContainer


class SamplerOverrideListResponse(SamplerOverridesContainer):
    """Sampler override list response"""

    presets: Optional[List[str]]


class SamplerOverrideSwitchRequest(BaseModel):
    """Sampler override switch request"""

    preset: Optional[str] = Field(
        default=None, description="Pass a sampler override preset name"
    )

    overrides: Optional[dict] = Field(
        default=None,
        description=(
            "Sampling override parent takes in individual keys and overrides. "
            + "Ignored if preset is provided."
        ),
        examples=[
            {
                "top_p": {
                    "override": 1.5,
                    "force": False,
                }
            }
        ],
    )
@@ -1,15 +0,0 @@
from pydantic import BaseModel, Field
from typing import List


class TemplateList(BaseModel):
    """Represents a list of templates."""

    object: str = "list"
    data: List[str] = Field(default_factory=list)


class TemplateSwitchRequest(BaseModel):
    """Request to switch a template."""

    name: str
@@ -1,51 +0,0 @@
"""Tokenization types"""

from pydantic import BaseModel
from typing import Dict, List, Union


class CommonTokenRequest(BaseModel):
    """Represents a common tokenization request."""

    add_bos_token: bool = True
    encode_special_tokens: bool = True
    decode_special_tokens: bool = True

    def get_params(self):
        """Get the parameters for tokenization."""
        return {
            "add_bos_token": self.add_bos_token,
            "encode_special_tokens": self.encode_special_tokens,
            "decode_special_tokens": self.decode_special_tokens,
        }


class TokenEncodeRequest(CommonTokenRequest):
    """Represents a tokenization request."""

    text: Union[str, List[Dict[str, str]]]


class TokenEncodeResponse(BaseModel):
    """Represents a tokenization response."""

    tokens: List[int]
    length: int


class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""

    tokens: List[int]


class TokenDecodeResponse(BaseModel):
    """Represents a detokenization response."""

    text: str


class TokenCountResponse(BaseModel):
    """Represents a token count response."""

    length: int
@@ -1,30 +0,0 @@
import pathlib

from common import model
from endpoints.OAI.types.lora import LoraCard, LoraList


def get_lora_list(lora_path: pathlib.Path):
    """Get the list of Lora cards from the provided path."""
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id=path.name)
            lora_list.data.append(lora_card)

    return lora_list


def get_active_loras():
    if model.container:
        active_loras = [
            LoraCard(
                id=pathlib.Path(lora.lora_path).parent.name,
                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            )
            for lora in model.container.get_loras()
        ]
    else:
        active_loras = []

    return LoraList(data=active_loras)
@@ -1,123 +0,0 @@
import pathlib
from asyncio import CancelledError
from typing import Optional

from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.model import (
    ModelCard,
    ModelCardParameters,
    ModelList,
    ModelLoadRequest,
    ModelLoadResponse,
)


def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
    """Get the list of models from the provided path."""

    # Convert the provided draft model path to a pathlib path for
    # equality comparisons
    if draft_model_path:
        draft_model_path = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():
        # Don't include the draft models path
        if path.is_dir() and path != draft_model_path:
            model_card = ModelCard(id=path.name)
            model_card_list.data.append(model_card)  # pylint: disable=no-member

    return model_card_list


async def get_current_model_list(is_draft: bool = False):
    """Gets the current model in list format and with path only."""
    current_models = []

    # Make sure the model container exists
    if model.container:
        model_path = model.container.get_model_path(is_draft)
        if model_path:
            current_models.append(ModelCard(id=model_path.name))

    return ModelList(data=current_models)


def get_current_model():
    """Gets the current model with all parameters."""

    model_params = model.container.get_model_parameters()
    draft_model_params = model_params.pop("draft", {})

    if draft_model_params:
        model_params["draft"] = ModelCard(
            id=unwrap(draft_model_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )
    else:
        draft_model_params = None

    model_card = ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )

    if draft_model_params:
        draft_card = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

        model_card.parameters.draft = draft_card

    return model_card


async def stream_model_load(
    data: ModelLoadRequest,
    model_path: pathlib.Path,
    draft_model_path: str,
):
    """Request generation wrapper for the loading process."""

    # Set the draft model path if it exists
    load_data = data.model_dump()
    if draft_model_path:
        load_data["draft"]["draft_model_dir"] = draft_model_path

    load_status = model.load_model_gen(
        model_path, skip_wait=data.skip_queue, **load_data
    )
    try:
        async for module, modules, model_type in load_status:
            if module != 0:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="processing",
                )

                yield response.model_dump_json()

            if module == modules:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="finished",
                )

                yield response.model_dump_json()
    except CancelledError:
        # Get out if the request gets disconnected

        handle_request_disconnect(
            "Model load cancelled by user. "
            "Please make sure to run unload to free up resources."
        )
    except Exception as exc:
        yield get_generator_error(str(exc))