mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-25 16:59:09 +00:00
API: Move OAI to APIRouter
This makes the API more modular for other API implementations in the future. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import pathlib
|
import pathlib
|
||||||
import uvicorn
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
from fastapi import FastAPI, Depends, HTTPException, Header, Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
@@ -15,7 +13,6 @@ from common.concurrency import (
|
|||||||
call_with_semaphore,
|
call_with_semaphore,
|
||||||
generate_with_semaphore,
|
generate_with_semaphore,
|
||||||
)
|
)
|
||||||
from common.logger import UVICORN_LOG_CONFIG
|
|
||||||
from common.networking import handle_request_error, run_with_request_disconnect
|
from common.networking import handle_request_error, run_with_request_disconnect
|
||||||
from common.templating import (
|
from common.templating import (
|
||||||
get_all_templates,
|
get_all_templates,
|
||||||
@@ -56,23 +53,8 @@ from endpoints.OAI.utils.completion import (
|
|||||||
from endpoints.OAI.utils.model import get_model_list, stream_model_load
|
from endpoints.OAI.utils.model import get_model_list, stream_model_load
|
||||||
from endpoints.OAI.utils.lora import get_lora_list
|
from endpoints.OAI.utils.lora import get_lora_list
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title="TabbyAPI",
|
|
||||||
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
|
|
||||||
description=(
|
|
||||||
"This docs page is not meant to send requests! Please use a service "
|
|
||||||
"like Postman or a frontend UI."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allow CORS requests
|
router = APIRouter()
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def check_model_container():
|
async def check_model_container():
|
||||||
@@ -90,8 +72,8 @@ async def check_model_container():
|
|||||||
|
|
||||||
|
|
||||||
# Model list endpoint
|
# Model list endpoint
|
||||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_models():
|
async def list_models():
|
||||||
"""Lists all models in the model directory."""
|
"""Lists all models in the model directory."""
|
||||||
model_config = config.model_config()
|
model_config = config.model_config()
|
||||||
@@ -108,7 +90,7 @@ async def list_models():
|
|||||||
|
|
||||||
|
|
||||||
# Currently loaded model endpoint
|
# Currently loaded model endpoint
|
||||||
@app.get(
|
@router.get(
|
||||||
"/v1/model",
|
"/v1/model",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -142,7 +124,7 @@ async def get_current_model():
|
|||||||
return model_card
|
return model_card
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_draft_models():
|
async def list_draft_models():
|
||||||
"""Lists all draft models in the model directory."""
|
"""Lists all draft models in the model directory."""
|
||||||
draft_model_dir = unwrap(
|
draft_model_dir = unwrap(
|
||||||
@@ -156,7 +138,7 @@ async def list_draft_models():
|
|||||||
|
|
||||||
|
|
||||||
# Load model endpoint
|
# Load model endpoint
|
||||||
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
||||||
async def load_model(request: Request, data: ModelLoadRequest):
|
async def load_model(request: Request, data: ModelLoadRequest):
|
||||||
"""Loads a model into the model container."""
|
"""Loads a model into the model container."""
|
||||||
|
|
||||||
@@ -209,7 +191,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||||||
|
|
||||||
|
|
||||||
# Unload model endpoint
|
# Unload model endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/model/unload",
|
"/v1/model/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -218,15 +200,15 @@ async def unload_model():
|
|||||||
await model.unload_model()
|
await model.unload_model()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
||||||
async def get_templates():
|
async def get_templates():
|
||||||
templates = get_all_templates()
|
templates = get_all_templates()
|
||||||
template_strings = list(map(lambda template: template.stem, templates))
|
template_strings = list(map(lambda template: template.stem, templates))
|
||||||
return TemplateList(data=template_strings)
|
return TemplateList(data=template_strings)
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/template/switch",
|
"/v1/template/switch",
|
||||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -252,7 +234,7 @@ async def switch_template(data: TemplateSwitchRequest):
|
|||||||
raise HTTPException(400, error_message) from e
|
raise HTTPException(400, error_message) from e
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/template/unload",
|
"/v1/template/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -263,15 +245,15 @@ async def unload_template():
|
|||||||
|
|
||||||
|
|
||||||
# Sampler override endpoints
|
# Sampler override endpoints
|
||||||
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_sampler_overrides():
|
async def list_sampler_overrides():
|
||||||
"""API wrapper to list all currently applied sampler overrides"""
|
"""API wrapper to list all currently applied sampler overrides"""
|
||||||
|
|
||||||
return sampling.overrides
|
return sampling.overrides
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/sampling/override/switch",
|
"/v1/sampling/override/switch",
|
||||||
dependencies=[Depends(check_admin_key)],
|
dependencies=[Depends(check_admin_key)],
|
||||||
)
|
)
|
||||||
@@ -300,7 +282,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/sampling/override/unload",
|
"/v1/sampling/override/unload",
|
||||||
dependencies=[Depends(check_admin_key)],
|
dependencies=[Depends(check_admin_key)],
|
||||||
)
|
)
|
||||||
@@ -311,8 +293,8 @@ async def unload_sampler_override():
|
|||||||
|
|
||||||
|
|
||||||
# Lora list endpoint
|
# Lora list endpoint
|
||||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||||
async def get_all_loras():
|
async def get_all_loras():
|
||||||
"""Lists all LoRAs in the lora directory."""
|
"""Lists all LoRAs in the lora directory."""
|
||||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||||
@@ -322,7 +304,7 @@ async def get_all_loras():
|
|||||||
|
|
||||||
|
|
||||||
# Currently loaded loras endpoint
|
# Currently loaded loras endpoint
|
||||||
@app.get(
|
@router.get(
|
||||||
"/v1/lora",
|
"/v1/lora",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -344,7 +326,7 @@ async def get_active_loras():
|
|||||||
|
|
||||||
|
|
||||||
# Load lora endpoint
|
# Load lora endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/lora/load",
|
"/v1/lora/load",
|
||||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -388,7 +370,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||||||
|
|
||||||
|
|
||||||
# Unload lora endpoint
|
# Unload lora endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/lora/unload",
|
"/v1/lora/unload",
|
||||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -399,7 +381,7 @@ async def unload_loras():
|
|||||||
|
|
||||||
|
|
||||||
# Encode tokens endpoint
|
# Encode tokens endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/token/encode",
|
"/v1/token/encode",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -413,7 +395,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||||||
|
|
||||||
|
|
||||||
# Decode tokens endpoint
|
# Decode tokens endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/token/decode",
|
"/v1/token/decode",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -425,7 +407,7 @@ async def decode_tokens(data: TokenDecodeRequest):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
|
||||||
async def get_key_permission(
|
async def get_key_permission(
|
||||||
x_admin_key: Optional[str] = Header(None),
|
x_admin_key: Optional[str] = Header(None),
|
||||||
x_api_key: Optional[str] = Header(None),
|
x_api_key: Optional[str] = Header(None),
|
||||||
@@ -452,7 +434,7 @@ async def get_key_permission(
|
|||||||
|
|
||||||
|
|
||||||
# Completions endpoint
|
# Completions endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -488,7 +470,7 @@ async def completion_request(request: Request, data: CompletionRequest):
|
|||||||
|
|
||||||
|
|
||||||
# Chat completions endpoint
|
# Chat completions endpoint
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
)
|
)
|
||||||
@@ -536,22 +518,3 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
|
|||||||
disconnect_message="Chat completion generation cancelled by user.",
|
disconnect_message="Chat completion generation cancelled by user.",
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def start_api(host: str, port: int):
|
|
||||||
"""Isolated function to start the API server"""
|
|
||||||
|
|
||||||
# TODO: Move OAI API to a separate folder
|
|
||||||
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
|
|
||||||
logger.info(f"Completions: http://{host}:{port}/v1/completions")
|
|
||||||
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
|
|
||||||
|
|
||||||
config = uvicorn.Config(
|
|
||||||
app,
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
log_config=UVICORN_LOG_CONFIG,
|
|
||||||
)
|
|
||||||
server = uvicorn.Server(config)
|
|
||||||
|
|
||||||
await server.serve()
|
|
||||||
47
endpoints/server.py
Normal file
47
endpoints/server.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from common.logger import UVICORN_LOG_CONFIG
|
||||||
|
from endpoints.OAI.router import router as OAIRouter
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="TabbyAPI",
|
||||||
|
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
|
||||||
|
description=(
|
||||||
|
"This docs page is not meant to send requests! Please use a service "
|
||||||
|
"like Postman or a frontend UI."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow CORS requests
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_api(host: str, port: int):
|
||||||
|
"""Isolated function to start the API server"""
|
||||||
|
|
||||||
|
# TODO: Move OAI API to a separate folder
|
||||||
|
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
|
||||||
|
logger.info(f"Completions: http://{host}:{port}/v1/completions")
|
||||||
|
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
|
||||||
|
|
||||||
|
# Add OAI router
|
||||||
|
app.include_router(OAIRouter)
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
|
app,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
log_config=UVICORN_LOG_CONFIG,
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
|
await server.serve()
|
||||||
2
main.py
2
main.py
@@ -15,7 +15,7 @@ from common.logger import setup_logger
|
|||||||
from common.networking import is_port_in_use
|
from common.networking import is_port_in_use
|
||||||
from common.signals import signal_handler
|
from common.signals import signal_handler
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.OAI.app import start_api
|
from endpoints.server import start_api
|
||||||
|
|
||||||
|
|
||||||
async def entrypoint(args: Optional[dict] = None):
|
async def entrypoint(args: Optional[dict] = None):
|
||||||
|
|||||||
Reference in New Issue
Block a user