mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
API: Add template switching and unload endpoints
Templates can be switched and unloaded without reloading the entire model.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -7,3 +7,9 @@ class TemplateList(BaseModel):
|
||||
|
||||
object: str = "list"
|
||||
data: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TemplateSwitchRequest(BaseModel):
    """Payload for a template switch request."""

    # Name of the template to switch to.
    name: str
|
||||
|
||||
@@ -163,30 +163,27 @@ class ExllamaV2Container:
|
||||
if prompt_template_name:
|
||||
logger.info("Loading prompt template with name " f"{prompt_template_name}")
|
||||
# Read the template
|
||||
self.prompt_template = get_template_from_file(prompt_template_name)
|
||||
else:
|
||||
# Then try finding the template from the tokenizer_config.json
|
||||
self.prompt_template = get_template_from_model_json(
|
||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||
"chat_template",
|
||||
"from_tokenizer_config",
|
||||
)
|
||||
try:
|
||||
self.prompt_template = get_template_from_file(prompt_template_name)
|
||||
except FileNotFoundError:
|
||||
self.prompt_template = None
|
||||
|
||||
# Try finding the chat template from the model's config.json
|
||||
# TODO: This may not even be used with huggingface models,
|
||||
# mark for removal.
|
||||
if self.prompt_template is None:
|
||||
# Then try finding the template from the tokenizer_config.json
|
||||
try:
|
||||
self.prompt_template = get_template_from_model_json(
|
||||
pathlib.Path(self.config.model_config),
|
||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||
"chat_template",
|
||||
"from_model_config",
|
||||
"from_tokenizer_config",
|
||||
)
|
||||
except FileNotFoundError:
|
||||
self.prompt_template = None
|
||||
|
||||
# If that fails, attempt fetching from model name
|
||||
if self.prompt_template is None:
|
||||
try:
|
||||
template_match = find_template_from_model(model_directory)
|
||||
if template_match:
|
||||
self.prompt_template = get_template_from_file(template_match)
|
||||
self.prompt_template = get_template_from_file(template_match)
|
||||
except (LookupError, FileNotFoundError):
|
||||
self.prompt_template = None
|
||||
|
||||
# Catch all for template lookup errors
|
||||
if self.prompt_template:
|
||||
|
||||
@@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path):
|
||||
"""Find a matching template name from a model path."""
|
||||
model_name = model_path.name
|
||||
template_files = get_all_templates()
|
||||
|
||||
for filepath in template_files:
|
||||
template_name = filepath.stem.lower()
|
||||
|
||||
# Check if the template name is present in the model name
|
||||
if template_name in model_name.lower():
|
||||
return template_name
|
||||
|
||||
return None
|
||||
else:
|
||||
raise LookupError("Could not find template from model name.")
|
||||
|
||||
|
||||
def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file.

    The file is looked up at ``templates/<prompt_template_name>.jinja``,
    resolved relative to the current working directory.

    Args:
        prompt_template_name: Template file name without the ``.jinja``
            suffix.

    Returns:
        A PromptTemplate built from the given name and the file's contents.

    Raises:
        FileNotFoundError: If the template file does not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Let the user know if the template file isn't found
    if not template_path.exists():
        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')

    with open(template_path, "r", encoding="utf8") as raw_template:
        return PromptTemplate(
            name=prompt_template_name, template=raw_template.read()
        )
|
||||
|
||||
|
||||
# Get a template from a JSON file
|
||||
@@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
|
||||
chat_template = model_config.get(key)
|
||||
if chat_template:
|
||||
return PromptTemplate(name=name, template=chat_template)
|
||||
|
||||
return None
|
||||
else:
|
||||
raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
|
||||
|
||||
34
main.py
34
main.py
@@ -27,7 +27,11 @@ from common.config import (
|
||||
)
|
||||
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||
from common.sampling import get_overrides_from_file
|
||||
from common.templating import get_all_templates, get_prompt_from_template
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from common.logger import init_logger
|
||||
from OAI.types.completion import CompletionRequest
|
||||
@@ -39,7 +43,7 @@ from OAI.types.model import (
|
||||
ModelLoadResponse,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from OAI.types.template import TemplateList
|
||||
from OAI.types.template import TemplateList, TemplateSwitchRequest
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
TokenEncodeResponse,
|
||||
@@ -258,6 +262,32 @@ async def get_templates():
|
||||
return TemplateList(data=template_strings)
|
||||
|
||||
|
||||
@app.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Replace the model container's prompt template with the named one.

    Responds with HTTP 400 when the request carries an empty name or when
    no template file matches the requested name.
    """

    requested_name = data.name
    if not requested_name:
        raise HTTPException(400, "New template name not found.")

    try:
        new_template = get_template_from_file(requested_name)
    except FileNotFoundError as exc:
        raise HTTPException(400, "Template does not exist. Check the name?") from exc
    else:
        # Only swap the template once the file has been read successfully.
        MODEL_CONTAINER.prompt_template = new_template
|
||||
|
||||
|
||||
@app.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_template():
    """Clear the prompt template currently set on the model container."""

    # Resetting to None leaves the container with no selected template.
    MODEL_CONTAINER.prompt_template = None
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
|
||||
Reference in New Issue
Block a user