diff --git a/OAI/types/template.py b/OAI/types/template.py
index 0374547..d72d621 100644
--- a/OAI/types/template.py
+++ b/OAI/types/template.py
@@ -7,3 +7,9 @@ class TemplateList(BaseModel):
 
     object: str = "list"
     data: List[str] = Field(default_factory=list)
+
+
+class TemplateSwitchRequest(BaseModel):
+    """Request to switch a template."""
+
+    name: str
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 6ee186b..0f91b05 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -163,30 +163,31 @@ class ExllamaV2Container:
         if prompt_template_name:
             logger.info("Loading prompt template with name " f"{prompt_template_name}")
             # Read the template
-            self.prompt_template = get_template_from_file(prompt_template_name)
-        else:
-            # Then try finding the template from the tokenizer_config.json
-            self.prompt_template = get_template_from_model_json(
-                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                "chat_template",
-                "from_tokenizer_config",
-            )
+            try:
+                self.prompt_template = get_template_from_file(prompt_template_name)
+            except FileNotFoundError:
+                self.prompt_template = None
+        else:
+            self.prompt_template = None
 
-        # Try finding the chat template from the model's config.json
-        # TODO: This may not even be used with huggingface models,
-        # mark for removal.
+        # Then try finding the template from the tokenizer_config.json
         if self.prompt_template is None:
-            self.prompt_template = get_template_from_model_json(
-                pathlib.Path(self.config.model_config),
-                "chat_template",
-                "from_model_config",
-            )
+            try:
+                self.prompt_template = get_template_from_model_json(
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                    "chat_template",
+                    "from_tokenizer_config",
+                )
+            except FileNotFoundError:
+                self.prompt_template = None
 
         # If that fails, attempt fetching from model name
         if self.prompt_template is None:
-            template_match = find_template_from_model(model_directory)
-            if template_match:
-                self.prompt_template = get_template_from_file(template_match)
+            try:
+                template_match = find_template_from_model(model_directory)
+                self.prompt_template = get_template_from_file(template_match)
+            except (LookupError, FileNotFoundError):
+                self.prompt_template = None
 
         # Catch all for template lookup errors
         if self.prompt_template:
diff --git a/common/templating.py b/common/templating.py
index ddc0ca1..bbd0596 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -68,26 +68,29 @@ def find_template_from_model(model_path: pathlib.Path):
     """Find a matching template name from a model path."""
     model_name = model_path.name
     template_files = get_all_templates()
+
     for filepath in template_files:
         template_name = filepath.stem.lower()
 
         # Check if the template name is present in the model name
         if template_name in model_name.lower():
             return template_name
-
-    return None
+    else:
+        raise LookupError("Could not find template from model name.")
 
 
 def get_template_from_file(prompt_template_name: str):
     """Get a template from a jinja file."""
+
     template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
     if template_path.exists():
         with open(template_path, "r", encoding="utf8") as raw_template:
             return PromptTemplate(
                 name=prompt_template_name, template=raw_template.read()
             )
-
-    return None
+    else:
+        # Let the user know if the template file isn't found
+        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
 
 
 # Get a template from a JSON file
@@ -100,5 +103,5 @@ def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
     chat_template = model_config.get(key)
     if chat_template:
         return PromptTemplate(name=name, template=chat_template)
-
-    return None
+    else:
+        raise FileNotFoundError(f'Chat template not found in model JSON "{json_path}".')
diff --git a/main.py b/main.py
index 218d9c0..df0d13d 100644
--- a/main.py
+++ b/main.py
@@ -27,7 +27,11 @@ from common.config import (
 )
 from common.generators import call_with_semaphore, generate_with_semaphore
 from common.sampling import get_overrides_from_file
-from common.templating import get_all_templates, get_prompt_from_template
+from common.templating import (
+    get_all_templates,
+    get_prompt_from_template,
+    get_template_from_file,
+)
 from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
 from common.logger import init_logger
 from OAI.types.completion import CompletionRequest
@@ -39,7 +43,7 @@ from OAI.types.model import (
     ModelLoadResponse,
     ModelCardParameters,
 )
-from OAI.types.template import TemplateList
+from OAI.types.template import TemplateList, TemplateSwitchRequest
 from OAI.types.token import (
     TokenEncodeRequest,
     TokenEncodeResponse,
@@ -258,6 +262,32 @@ async def get_templates():
     return TemplateList(data=template_strings)
 
 
+@app.post(
+    "/v1/template/switch",
+    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+)
+async def switch_template(data: TemplateSwitchRequest):
+    """Switch the currently loaded template"""
+    if not data.name:
+        raise HTTPException(400, "New template name not found.")
+
+    try:
+        template = get_template_from_file(data.name)
+        MODEL_CONTAINER.prompt_template = template
+    except FileNotFoundError as e:
+        raise HTTPException(400, "Template does not exist. Check the name?") from e
+
+
+@app.post(
+    "/v1/template/unload",
+    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+)
+async def unload_template():
+    """Unloads the currently selected template"""
+
+    MODEL_CONTAINER.prompt_template = None
+
+
 # Lora list endpoint
 @app.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])