OAI: Reorder functions

Reordering the route definitions changes the order in which endpoints appear in the generated API documentation (see the sketch below).

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-07-08 15:27:08 -04:00
parent 521d21b9f2
commit 5c293499bd


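A minimal standalone sketch (not part of this repository) of the behavior the commit message relies on: FastAPI writes routes into its OpenAPI schema in the order their decorators register them, so the docs page lists endpoints in declaration order. Moving the completion routes above the management routes therefore moves them to the top of the generated documentation. The app, routes, and handlers below are placeholders.

# Minimal sketch, separate from tabbyAPI: declaration order drives documentation order.
from fastapi import FastAPI

app = FastAPI()


@app.post("/v1/completions")
async def completions():
    return {}


@app.post("/v1/chat/completions")
async def chat_completions():
    return {}


@app.get("/v1/models")
async def models():
    return {}


# The OpenAPI "paths" mapping preserves registration order, and the docs UI renders it as-is.
print(list(app.openapi()["paths"].keys()))
# ['/v1/completions', '/v1/chat/completions', '/v1/models']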
@@ -72,6 +72,104 @@ async def check_model_container():
         raise HTTPException(400, error_message)
 
 
+# Completions endpoint
+@router.post(
+    "/v1/completions",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def completion_request(
+    request: Request, data: CompletionRequest
+) -> CompletionResponse:
+    """
+    Generates a completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
+    model_path = model.container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
+    disable_request_streaming = unwrap(
+        config.developer_config().get("disable_request_streaming"), False
+    )
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    if data.stream and not disable_request_streaming:
+        return EventSourceResponse(
+            stream_generate_completion(data, request, model_path),
+            ping=maxsize,
+        )
+    else:
+        generate_task = asyncio.create_task(generate_completion(data, model_path))
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message="Completion generation cancelled by user.",
+        )
+
+        return response
+
+
+# Chat completions endpoint
+@router.post(
+    "/v1/chat/completions",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def chat_completion_request(
+    request: Request, data: ChatCompletionRequest
+) -> ChatCompletionResponse:
+    """
+    Generates a chat completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
+    if model.container.prompt_template is None:
+        error_message = handle_request_error(
+            "Chat completions are disabled because a prompt template is not set.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(422, error_message)
+
+    model_path = model.container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = format_prompt_with_template(data)
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    disable_request_streaming = unwrap(
+        config.developer_config().get("disable_request_streaming"), False
+    )
+
+    if data.stream and not disable_request_streaming:
+        return EventSourceResponse(
+            stream_generate_chat_completion(prompt, data, request, model_path),
+            ping=maxsize,
+        )
+    else:
+        generate_task = asyncio.create_task(
+            generate_chat_completion(prompt, data, model_path)
+        )
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message="Chat completion generation cancelled by user.",
+        )
+
+        return response
+
+
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
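The docstrings of the two relocated endpoints note that stream = true turns the response into a server-sent events (SSE) stream via EventSourceResponse. A client-side sketch of consuming that stream follows; the host, port, and x-api-key value are assumptions, not values taken from this repository.

# Client-side sketch; URL and API key below are placeholders.
import json

import httpx

payload = {"prompt": "Once upon a time", "max_tokens": 32, "stream": True}
headers = {"x-api-key": "example-key"}

with httpx.stream(
    "POST", "http://127.0.0.1:5000/v1/completions", json=payload, headers=headers
) as response:
    for line in response.iter_lines():
        # SSE payload lines look like "data: {...}"; a terminal "[DONE]" may be sent.
        if not line.startswith("data: "):
            continue
        chunk = line.removeprefix("data: ")
        if chunk.strip() == "[DONE]":
            break
        print(json.loads(chunk))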
@@ -192,99 +290,6 @@ async def unload_model():
     await model.unload_model(skip_wait=True)
 
 
-@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
-@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def get_templates() -> TemplateList:
-    templates = get_all_templates()
-    template_strings = [template.stem for template in templates]
-
-    return TemplateList(data=template_strings)
-
-
-@router.post(
-    "/v1/template/switch",
-    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
-)
-async def switch_template(data: TemplateSwitchRequest):
-    """Switch the currently loaded template"""
-
-    if not data.name:
-        error_message = handle_request_error(
-            "New template name not found.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-    try:
-        model.container.prompt_template = PromptTemplate.from_file(data.name)
-    except FileNotFoundError as e:
-        error_message = handle_request_error(
-            f"The template name {data.name} doesn't exist. Check the spelling?",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message) from e
-
-
-@router.post(
-    "/v1/template/unload",
-    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
-)
-async def unload_template():
-    """Unloads the currently selected template"""
-
-    model.container.prompt_template = None
-
-
-# Sampler override endpoints
-@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
-@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides() -> SamplerOverrideListResponse:
-    """API wrapper to list all currently applied sampler overrides"""
-
-    return SamplerOverrideListResponse(
-        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
-    )
-
-
-@router.post(
-    "/v1/sampling/override/switch",
-    dependencies=[Depends(check_admin_key)],
-)
-async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
-    """Switch the currently loaded override preset"""
-
-    if data.preset:
-        try:
-            sampling.overrides_from_file(data.preset)
-        except FileNotFoundError as e:
-            error_message = handle_request_error(
-                f"Sampler override preset with name {data.preset} does not exist. "
-                + "Check the spelling?",
-                exc_info=False,
-            ).error.message
-
-            raise HTTPException(400, error_message) from e
-    elif data.overrides:
-        sampling.overrides_from_dict(data.overrides)
-    else:
-        error_message = handle_request_error(
-            "A sampler override preset or dictionary wasn't provided.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-
-@router.post(
-    "/v1/sampling/override/unload",
-    dependencies=[Depends(check_admin_key)],
-)
-async def unload_sampler_override():
-    """Unloads the currently selected override preset"""
-
-    sampling.overrides_from_dict({})
-
-
 @router.post("/v1/download", dependencies=[Depends(check_admin_key)])
 async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
     """Downloads a model from HuggingFace."""
@@ -452,99 +457,94 @@ async def get_key_permission(
         raise HTTPException(400, error_message) from exc
 
 
-# Completions endpoint
-@router.post(
-    "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
-)
-async def completion_request(
-    request: Request, data: CompletionRequest
-) -> CompletionResponse:
-    """
-    Generates a completion from a prompt.
-
-    If stream = true, this returns an SSE stream.
-    """
-
-    model_path = model.container.get_model_path()
-
-    if isinstance(data.prompt, list):
-        data.prompt = "\n".join(data.prompt)
-
-    disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
-    )
-
-    # Set an empty JSON schema if the request wants a JSON response
-    if data.response_format.type == "json":
-        data.json_schema = {"type": "object"}
-
-    if data.stream and not disable_request_streaming:
-        return EventSourceResponse(
-            stream_generate_completion(data, request, model_path),
-            ping=maxsize,
-        )
-    else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
-
-        response = await run_with_request_disconnect(
-            request,
-            generate_task,
-            disconnect_message="Completion generation cancelled by user.",
-        )
-
-        return response
-
-
-# Chat completions endpoint
-@router.post(
-    "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
-)
-async def chat_completion_request(
-    request: Request, data: ChatCompletionRequest
-) -> ChatCompletionResponse:
-    """
-    Generates a chat completion from a prompt.
-
-    If stream = true, this returns an SSE stream.
-    """
-
-    if model.container.prompt_template is None:
-        error_message = handle_request_error(
-            "Chat completions are disabled because a prompt template is not set.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(422, error_message)
-
-    model_path = model.container.get_model_path()
-
-    if isinstance(data.messages, str):
-        prompt = data.messages
-    else:
-        prompt = format_prompt_with_template(data)
-
-    # Set an empty JSON schema if the request wants a JSON response
-    if data.response_format.type == "json":
-        data.json_schema = {"type": "object"}
-
-    disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
-    )
-
-    if data.stream and not disable_request_streaming:
-        return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
-            ping=maxsize,
-        )
-    else:
-        generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, model_path)
-        )
-
-        response = await run_with_request_disconnect(
-            request,
-            generate_task,
-            disconnect_message="Chat completion generation cancelled by user.",
-        )
-
-        return response
+@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
+@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
+async def get_templates() -> TemplateList:
+    templates = get_all_templates()
+    template_strings = [template.stem for template in templates]
+
+    return TemplateList(data=template_strings)
+
+
+@router.post(
+    "/v1/template/switch",
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
+)
+async def switch_template(data: TemplateSwitchRequest):
+    """Switch the currently loaded template"""
+
+    if not data.name:
+        error_message = handle_request_error(
+            "New template name not found.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+    try:
+        model.container.prompt_template = PromptTemplate.from_file(data.name)
+    except FileNotFoundError as e:
+        error_message = handle_request_error(
+            f"The template name {data.name} doesn't exist. Check the spelling?",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message) from e
+
+
+@router.post(
+    "/v1/template/unload",
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
+)
+async def unload_template():
+    """Unloads the currently selected template"""
+
+    model.container.prompt_template = None
+
+
+# Sampler override endpoints
+@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
+async def list_sampler_overrides() -> SamplerOverrideListResponse:
+    """API wrapper to list all currently applied sampler overrides"""
+
+    return SamplerOverrideListResponse(
+        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
+    )
+
+
+@router.post(
+    "/v1/sampling/override/switch",
+    dependencies=[Depends(check_admin_key)],
+)
+async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
+    """Switch the currently loaded override preset"""
+
+    if data.preset:
+        try:
+            sampling.overrides_from_file(data.preset)
+        except FileNotFoundError as e:
+            error_message = handle_request_error(
+                f"Sampler override preset with name {data.preset} does not exist. "
+                + "Check the spelling?",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(400, error_message) from e
+    elif data.overrides:
+        sampling.overrides_from_dict(data.overrides)
+    else:
+        error_message = handle_request_error(
+            "A sampler override preset or dictionary wasn't provided.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+
+@router.post(
+    "/v1/sampling/override/unload",
+    dependencies=[Depends(check_admin_key)],
+)
+async def unload_sampler_override():
+    """Unloads the currently selected override preset"""
+
+    sampling.overrides_from_dict({})
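Both relocated endpoints delegate non-streaming generation to run_with_request_disconnect, which is imported from elsewhere in the repository and is not part of this diff. A hypothetical sketch of how such a helper can be structured (not the project's actual implementation): await the generation task while polling the request, and cancel the task if the client disconnects.

# Hypothetical helper sketch; the real run_with_request_disconnect may differ.
import asyncio

from fastapi import HTTPException, Request


async def run_with_request_disconnect(
    request: Request, task: asyncio.Task, disconnect_message: str
):
    """Await a generation task and cancel it if the HTTP client goes away."""
    while not task.done():
        if await request.is_disconnected():
            task.cancel()
            raise HTTPException(422, disconnect_message)
        await asyncio.sleep(0.1)

    return task.result()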