mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 10:11:39 +00:00
API: Add template switching and unload endpoints
Templates can be switched and unloaded without reloading the entire model. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -7,3 +7,9 @@ class TemplateList(BaseModel):
|
|||||||
|
|
||||||
object: str = "list"
|
object: str = "list"
|
||||||
data: List[str] = Field(default_factory=list)
|
data: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateSwitchRequest(BaseModel):
|
||||||
|
"""Request to switch a template."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|||||||
@@ -163,30 +163,27 @@ class ExllamaV2Container:
|
|||||||
if prompt_template_name:
|
if prompt_template_name:
|
||||||
logger.info("Loading prompt template with name " f"{prompt_template_name}")
|
logger.info("Loading prompt template with name " f"{prompt_template_name}")
|
||||||
# Read the template
|
# Read the template
|
||||||
self.prompt_template = get_template_from_file(prompt_template_name)
|
try:
|
||||||
else:
|
self.prompt_template = get_template_from_file(prompt_template_name)
|
||||||
# Then try finding the template from the tokenizer_config.json
|
except FileNotFoundError:
|
||||||
self.prompt_template = get_template_from_model_json(
|
self.prompt_template = None
|
||||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
|
||||||
"chat_template",
|
|
||||||
"from_tokenizer_config",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try finding the chat template from the model's config.json
|
# Then try finding the template from the tokenizer_config.json
|
||||||
# TODO: This may not even be used with huggingface models,
|
try:
|
||||||
# mark for removal.
|
|
||||||
if self.prompt_template is None:
|
|
||||||
self.prompt_template = get_template_from_model_json(
|
self.prompt_template = get_template_from_model_json(
|
||||||
pathlib.Path(self.config.model_config),
|
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||||
"chat_template",
|
"chat_template",
|
||||||
"from_model_config",
|
"from_tokenizer_config",
|
||||||
)
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
self.prompt_template = None
|
||||||
|
|
||||||
# If that fails, attempt fetching from model name
|
# If that fails, attempt fetching from model name
|
||||||
if self.prompt_template is None:
|
try:
|
||||||
template_match = find_template_from_model(model_directory)
|
template_match = find_template_from_model(model_directory)
|
||||||
if template_match:
|
self.prompt_template = get_template_from_file(template_match)
|
||||||
self.prompt_template = get_template_from_file(template_match)
|
except (LookupError, FileNotFoundError):
|
||||||
|
self.prompt_template = None
|
||||||
|
|
||||||
# Catch all for template lookup errors
|
# Catch all for template lookup errors
|
||||||
if self.prompt_template:
|
if self.prompt_template:
|
||||||
|
|||||||
@@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path):
|
|||||||
"""Find a matching template name from a model path."""
|
"""Find a matching template name from a model path."""
|
||||||
model_name = model_path.name
|
model_name = model_path.name
|
||||||
template_files = get_all_templates()
|
template_files = get_all_templates()
|
||||||
|
|
||||||
for filepath in template_files:
|
for filepath in template_files:
|
||||||
template_name = filepath.stem.lower()
|
template_name = filepath.stem.lower()
|
||||||
|
|
||||||
# Check if the template name is present in the model name
|
# Check if the template name is present in the model name
|
||||||
if template_name in model_name.lower():
|
if template_name in model_name.lower():
|
||||||
return template_name
|
return template_name
|
||||||
|
else:
|
||||||
return None
|
raise LookupError("Could not find template from model name.")
|
||||||
|
|
||||||
|
|
||||||
def get_template_from_file(prompt_template_name: str):
|
def get_template_from_file(prompt_template_name: str):
|
||||||
"""Get a template from a jinja file."""
|
"""Get a template from a jinja file."""
|
||||||
|
|
||||||
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
|
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
|
||||||
if template_path.exists():
|
if template_path.exists():
|
||||||
with open(template_path, "r", encoding="utf8") as raw_template:
|
with open(template_path, "r", encoding="utf8") as raw_template:
|
||||||
return PromptTemplate(
|
return PromptTemplate(
|
||||||
name=prompt_template_name, template=raw_template.read()
|
name=prompt_template_name, template=raw_template.read()
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
return None
|
# Let the user know if the template file isn't found
|
||||||
|
raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
|
||||||
|
|
||||||
|
|
||||||
# Get a template from a JSON file
|
# Get a template from a JSON file
|
||||||
@@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
|
|||||||
chat_template = model_config.get(key)
|
chat_template = model_config.get(key)
|
||||||
if chat_template:
|
if chat_template:
|
||||||
return PromptTemplate(name=name, template=chat_template)
|
return PromptTemplate(name=name, template=chat_template)
|
||||||
|
else:
|
||||||
return None
|
raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
|
||||||
|
|||||||
34
main.py
34
main.py
@@ -27,7 +27,11 @@ from common.config import (
|
|||||||
)
|
)
|
||||||
from common.generators import call_with_semaphore, generate_with_semaphore
|
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||||
from common.sampling import get_overrides_from_file
|
from common.sampling import get_overrides_from_file
|
||||||
from common.templating import get_all_templates, get_prompt_from_template
|
from common.templating import (
|
||||||
|
get_all_templates,
|
||||||
|
get_prompt_from_template,
|
||||||
|
get_template_from_file,
|
||||||
|
)
|
||||||
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||||
from common.logger import init_logger
|
from common.logger import init_logger
|
||||||
from OAI.types.completion import CompletionRequest
|
from OAI.types.completion import CompletionRequest
|
||||||
@@ -39,7 +43,7 @@ from OAI.types.model import (
|
|||||||
ModelLoadResponse,
|
ModelLoadResponse,
|
||||||
ModelCardParameters,
|
ModelCardParameters,
|
||||||
)
|
)
|
||||||
from OAI.types.template import TemplateList
|
from OAI.types.template import TemplateList, TemplateSwitchRequest
|
||||||
from OAI.types.token import (
|
from OAI.types.token import (
|
||||||
TokenEncodeRequest,
|
TokenEncodeRequest,
|
||||||
TokenEncodeResponse,
|
TokenEncodeResponse,
|
||||||
@@ -258,6 +262,32 @@ async def get_templates():
|
|||||||
return TemplateList(data=template_strings)
|
return TemplateList(data=template_strings)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
"/v1/template/switch",
|
||||||
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
|
async def switch_template(data: TemplateSwitchRequest):
|
||||||
|
"""Switch the currently loaded template"""
|
||||||
|
if not data.name:
|
||||||
|
raise HTTPException(400, "New template name not found.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
template = get_template_from_file(data.name)
|
||||||
|
MODEL_CONTAINER.prompt_template = template
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
raise HTTPException(400, "Template does not exist. Check the name?") from e
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
"/v1/template/unload",
|
||||||
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
|
async def unload_template():
|
||||||
|
"""Unloads the currently selected template"""
|
||||||
|
|
||||||
|
MODEL_CONTAINER.prompt_template = None
|
||||||
|
|
||||||
|
|
||||||
# Lora list endpoint
|
# Lora list endpoint
|
||||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||||
|
|||||||
Reference in New Issue
Block a user