API: Add template switching and unload endpoints

Templates can be switched and unloaded without reloading the entire
model.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-22 23:13:52 -05:00
committed by Brian Dashore
parent 6c30f24c83
commit de0ba7214c
4 changed files with 61 additions and 25 deletions

34
main.py
View File

@@ -27,7 +27,11 @@ from common.config import (
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.templating import get_all_templates, get_prompt_from_template
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
from common.logger import init_logger
from OAI.types.completion import CompletionRequest
@@ -39,7 +43,7 @@ from OAI.types.model import (
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.template import TemplateList
from OAI.types.template import TemplateList, TemplateSwitchRequest
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
@@ -258,6 +262,32 @@ async def get_templates():
return TemplateList(data=template_strings)
@app.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Replace the model container's prompt template with the named one.

    Responds with HTTP 400 when the request carries no template name, or
    when loading the named template raises FileNotFoundError.
    """

    requested_name = data.name

    # Guard clause: a missing or empty name cannot be resolved to a template.
    if not requested_name:
        raise HTTPException(400, "New template name not found.")

    try:
        new_template = get_template_from_file(requested_name)
    except FileNotFoundError as exc:
        raise HTTPException(400, "Template does not exist. Check the name?") from exc
    else:
        # Only touch the container once the template has loaded successfully.
        MODEL_CONTAINER.prompt_template = new_template
@app.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_template():
    """Clear the prompt template held by the model container.

    The container itself is left in place; only its ``prompt_template``
    attribute is reset to None.
    """

    MODEL_CONTAINER.prompt_template = None
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])