mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
OAI: Reorder functions
Reordering the routes changes the order in which they appear in the generated documentation. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -72,6 +72,104 @@ async def check_model_container():
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
|
||||
# Completions endpoint
|
||||
@router.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def completion_request(
    request: Request, data: CompletionRequest
) -> CompletionResponse:
    """
    Handles a text completion request.

    A prompt given as a list is joined with newlines into one string.
    If the request sets stream and the developer config does not set
    disable_request_streaming, an SSE stream is returned. Otherwise the
    completion is generated in a task that is awaited through
    run_with_request_disconnect.
    """

    model_path = model.container.get_model_path()

    # Collapse a multi-part prompt into a single newline-separated string
    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    streaming_disabled = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    # A JSON response format is given a bare object schema
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    if data.stream and not streaming_disabled:
        event_source = stream_generate_completion(data, request, model_path)
        return EventSourceResponse(event_source, ping=maxsize)

    task = asyncio.create_task(generate_completion(data, model_path))
    return await run_with_request_disconnect(
        request,
        task,
        disconnect_message="Completion generation cancelled by user.",
    )
|
||||
|
||||
|
||||
# Chat completions endpoint
|
||||
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def chat_completion_request(
    request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
    """
    Handles a chat completion request.

    Raises a 422 HTTPException when the model container has no prompt
    template. Messages given as a plain string are used as the prompt
    directly; anything else is passed to format_prompt_with_template.

    If the request sets stream and the developer config does not set
    disable_request_streaming, an SSE stream is returned. Otherwise the
    chat completion is generated in a task that is awaited through
    run_with_request_disconnect.
    """

    if model.container.prompt_template is None:
        failure = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        )
        raise HTTPException(422, failure.error.message)

    model_path = model.container.get_model_path()

    # A raw string bypasses templating
    prompt = (
        data.messages
        if isinstance(data.messages, str)
        else format_prompt_with_template(data)
    )

    # A JSON response format is given a bare object schema
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    streaming_disabled = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    if data.stream and not streaming_disabled:
        event_source = stream_generate_chat_completion(
            prompt, data, request, model_path
        )
        return EventSourceResponse(event_source, ping=maxsize)

    task = asyncio.create_task(generate_chat_completion(prompt, data, model_path))
    return await run_with_request_disconnect(
        request,
        task,
        disconnect_message="Chat completion generation cancelled by user.",
    )
|
||||
|
||||
|
||||
# Model list endpoint
|
||||
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
@@ -192,99 +290,6 @@ async def unload_model():
|
||||
await model.unload_model(skip_wait=True)
|
||||
|
||||
|
||||
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates() -> TemplateList:
    """Lists the stem of every entry returned by get_all_templates()"""

    stems = [entry.stem for entry in get_all_templates()]
    return TemplateList(data=stems)
|
||||
|
||||
|
||||
@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """
    Replaces the container's prompt template with the one named in the request.

    Raises a 400 HTTPException when no name is given or when loading the
    named template raises FileNotFoundError.
    """

    # Reject a missing or empty name before loading anything
    if not data.name:
        failure = handle_request_error(
            "New template name not found.",
            exc_info=False,
        )
        raise HTTPException(400, failure.error.message)

    try:
        loaded_template = PromptTemplate.from_file(data.name)
        model.container.prompt_template = loaded_template
    except FileNotFoundError as exc:
        failure = handle_request_error(
            f"The template name {data.name} doesn't exist. Check the spelling?",
            exc_info=False,
        )
        raise HTTPException(400, failure.error.message) from exc
|
||||
|
||||
|
||||
@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Clears the container's prompt template by setting it to None"""

    container = model.container
    container.prompt_template = None
|
||||
|
||||
|
||||
# Sampler override endpoints
|
||||
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
    """
    Returns all presets together with the dumped fields of the current
    overrides container
    """

    available_presets = sampling.get_all_presets()
    current_overrides = sampling.overrides_container.model_dump()
    return SamplerOverrideListResponse(presets=available_presets, **current_overrides)
|
||||
|
||||
|
||||
@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """
    Applies sampler overrides from a named preset or an inline dictionary.

    A preset takes priority over the dictionary when both are provided.
    Raises a 400 HTTPException when loading the preset raises
    FileNotFoundError or when neither a preset nor overrides are given.
    """

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as exc:
            failure = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            )
            raise HTTPException(400, failure.error.message) from exc
        return

    if data.overrides:
        sampling.overrides_from_dict(data.overrides)
        return

    # Nothing usable was supplied in the request
    failure = handle_request_error(
        "A sampler override preset or dictionary wasn't provided.",
        exc_info=False,
    )
    raise HTTPException(400, failure.error.message)
|
||||
|
||||
|
||||
@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Resets the sampler overrides by loading an empty dictionary"""

    no_overrides = {}
    sampling.overrides_from_dict(no_overrides)
|
||||
|
||||
|
||||
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
|
||||
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
|
||||
"""Downloads a model from HuggingFace."""
|
||||
@@ -452,99 +457,94 @@ async def get_key_permission(
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
# Completions endpoint
|
||||
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates() -> TemplateList:
    """Lists the stem of every entry returned by get_all_templates()"""

    stems = [entry.stem for entry in get_all_templates()]
    return TemplateList(data=stems)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
"/v1/template/switch",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||
)
|
||||
async def completion_request(
|
||||
request: Request, data: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
"""
|
||||
Generates a completion from a prompt.
|
||||
|
||||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
model_path = model.container.get_model_path()
|
||||
|
||||
if isinstance(data.prompt, list):
|
||||
data.prompt = "\n".join(data.prompt)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
# Set an empty JSON schema if the request wants a JSON response
|
||||
if data.response_format.type == "json":
|
||||
data.json_schema = {"type": "object"}
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
return EventSourceResponse(
|
||||
stream_generate_completion(data, request, model_path),
|
||||
ping=maxsize,
|
||||
)
|
||||
else:
|
||||
generate_task = asyncio.create_task(generate_completion(data, model_path))
|
||||
|
||||
response = await run_with_request_disconnect(
|
||||
request,
|
||||
generate_task,
|
||||
disconnect_message="Completion generation cancelled by user.",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# Chat completions endpoint
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def chat_completion_request(
|
||||
request: Request, data: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Generates a chat completion from a prompt.
|
||||
|
||||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
if model.container.prompt_template is None:
|
||||
async def switch_template(data: TemplateSwitchRequest):
|
||||
"""Switch the currently loaded template"""
|
||||
if not data.name:
|
||||
error_message = handle_request_error(
|
||||
"Chat completions are disabled because a prompt template is not set.",
|
||||
"New template name not found.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(422, error_message)
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
model_path = model.container.get_model_path()
|
||||
try:
|
||||
model.container.prompt_template = PromptTemplate.from_file(data.name)
|
||||
except FileNotFoundError as e:
|
||||
error_message = handle_request_error(
|
||||
f"The template name {data.name} doesn't exist. Check the spelling?",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
if isinstance(data.messages, str):
|
||||
prompt = data.messages
|
||||
else:
|
||||
prompt = format_prompt_with_template(data)
|
||||
raise HTTPException(400, error_message) from e
|
||||
|
||||
# Set an empty JSON schema if the request wants a JSON response
|
||||
if data.response_format.type == "json":
|
||||
data.json_schema = {"type": "object"}
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Clears the container's prompt template by setting it to None"""

    container = model.container
    container.prompt_template = None
|
||||
|
||||
|
||||
# Sampler override endpoints
|
||||
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
    """
    Returns all presets together with the dumped fields of the current
    overrides container
    """

    available_presets = sampling.get_all_presets()
    current_overrides = sampling.overrides_container.model_dump()
    return SamplerOverrideListResponse(presets=available_presets, **current_overrides)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
return EventSourceResponse(
|
||||
stream_generate_chat_completion(prompt, data, request, model_path),
|
||||
ping=maxsize,
|
||||
)
|
||||
else:
|
||||
generate_task = asyncio.create_task(
|
||||
generate_chat_completion(prompt, data, model_path)
|
||||
)
|
||||
|
||||
response = await run_with_request_disconnect(
|
||||
request,
|
||||
generate_task,
|
||||
disconnect_message="Chat completion generation cancelled by user.",
|
||||
)
|
||||
return response
|
||||
@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """
    Applies sampler overrides from a named preset or an inline dictionary.

    A preset takes priority over the dictionary when both are provided.
    Raises a 400 HTTPException when loading the preset raises
    FileNotFoundError or when neither a preset nor overrides are given.
    """

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as exc:
            failure = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            )
            raise HTTPException(400, failure.error.message) from exc
        return

    if data.overrides:
        sampling.overrides_from_dict(data.overrides)
        return

    # Nothing usable was supplied in the request
    failure = handle_request_error(
        "A sampler override preset or dictionary wasn't provided.",
        exc_info=False,
    )
    raise HTTPException(400, failure.error.message)
|
||||
|
||||
|
||||
@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Resets the sampler overrides by loading an empty dictionary"""

    no_overrides = {}
    sampling.overrides_from_dict(no_overrides)
|
||||
|
||||
Reference in New Issue
Block a user