API: Add template switching and unload endpoints

Templates can be switched and unloaded without reloading the entire
model.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-22 23:13:52 -05:00
committed by Brian Dashore
parent 6c30f24c83
commit de0ba7214c
4 changed files with 61 additions and 25 deletions

View File

@@ -7,3 +7,9 @@ class TemplateList(BaseModel):
object: str = "list"
data: List[str] = Field(default_factory=list)
class TemplateSwitchRequest(BaseModel):
    """Request body for switching the active prompt template."""

    # Name of the template to switch to
    name: str

View File

@@ -163,30 +163,27 @@ class ExllamaV2Container:
if prompt_template_name:
logger.info("Loading prompt template with name " f"{prompt_template_name}")
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
# Then try finding the template from the tokenizer_config.json
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_tokenizer_config",
)
try:
self.prompt_template = get_template_from_file(prompt_template_name)
except FileNotFoundError:
self.prompt_template = None
# Try finding the chat template from the model's config.json
# TODO: This may not even be used with huggingface models,
# mark for removal.
if self.prompt_template is None:
# Then try finding the template from the tokenizer_config.json
try:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_config),
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_model_config",
"from_tokenizer_config",
)
except FileNotFoundError:
self.prompt_template = None
# If that fails, attempt fetching from model name
if self.prompt_template is None:
try:
template_match = find_template_from_model(model_directory)
if template_match:
self.prompt_template = get_template_from_file(template_match)
self.prompt_template = get_template_from_file(template_match)
except (LookupError, FileNotFoundError):
self.prompt_template = None
# Catch all for template lookup errors
if self.prompt_template:

View File

@@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_files = get_all_templates()
for filepath in template_files:
template_name = filepath.stem.lower()
# Check if the template name is present in the model name
if template_name in model_name.lower():
return template_name
return None
else:
raise LookupError("Could not find template from model name.")
def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file.

    Reads ``templates/<prompt_template_name>.jinja`` (relative to the current
    working directory) and wraps its contents in a PromptTemplate.

    Args:
        prompt_template_name: Template name, without the ``.jinja`` suffix.

    Returns:
        A PromptTemplate built from the given name and the file's text.

    Raises:
        FileNotFoundError: If the template file does not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Open directly rather than checking exists() first, so a file that
    # disappears between the check and the read is still reported as missing.
    try:
        with open(template_path, "r", encoding="utf8") as raw_template:
            raw_text = raw_template.read()
    except FileNotFoundError as exc:
        # Let the user know if the template file isn't found
        raise FileNotFoundError(
            f'Template "{prompt_template_name}" not found.'
        ) from exc

    return PromptTemplate(name=prompt_template_name, template=raw_text)
# Get a template from a JSON file
@@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
chat_template = model_config.get(key)
if chat_template:
return PromptTemplate(name=name, template=chat_template)
return None
else:
raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')

34
main.py
View File

@@ -27,7 +27,11 @@ from common.config import (
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.templating import get_all_templates, get_prompt_from_template
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
from common.logger import init_logger
from OAI.types.completion import CompletionRequest
@@ -39,7 +43,7 @@ from OAI.types.model import (
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.template import TemplateList
from OAI.types.template import TemplateList, TemplateSwitchRequest
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
@@ -258,6 +262,32 @@ async def get_templates():
return TemplateList(data=template_strings)
@app.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Replace the model container's prompt template with the named one."""

    requested_name = data.name

    # An empty name cannot refer to a template file
    if not requested_name:
        raise HTTPException(400, "New template name not found.")

    try:
        MODEL_CONTAINER.prompt_template = get_template_from_file(requested_name)
    except FileNotFoundError as exc:
        # Report a missing template as a client error, keeping the cause chained
        raise HTTPException(400, "Template does not exist. Check the name?") from exc
@app.post(
    "/v1/template/unload",
    dependencies=[
        Depends(check_admin_key),
        Depends(_check_model_container),
    ],
)
async def unload_template():
    """Clear the prompt template currently set on the model container."""

    # Only the template reference is reset here
    MODEL_CONTAINER.prompt_template = None
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])