Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-04-20 06:19:15 +00:00
Model: Fix prompt template initialization
The previous commit iterated through multiple try conditions, which meant the user had to provide a dummy prompt template. Now, template loading is fallback based: run through a loop of functions and return as soon as one of them succeeds.

Signed-off-by: kingbri <bdashore3@proton.me>
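The new approach in the diff below boils down to a first-success loop. A minimal standalone sketch of that pattern, using only the exceptions the new code already catches (the name first_successful is illustrative, not from the repo):

    def first_successful(loaders):
        """Try each loader callable in priority order; return the first result."""
        for load in loaders:
            try:
                result = load()
                if result is not None:
                    return result
            except (FileNotFoundError, LookupError):
                # This loader failed; fall through to the next one
                continue
        return None  # every fallback was exhausted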
@@ -158,32 +158,10 @@ class ExllamaV2Container:
             self.config.set_low_mem()
         """

-        # Set prompt template override if provided
-        prompt_template_name = kwargs.get("prompt_template")
-        if prompt_template_name:
-            logger.info("Loading prompt template with name " f"{prompt_template_name}")
-            # Read the template
-            try:
-                self.prompt_template = get_template_from_file(prompt_template_name)
-            except FileNotFoundError:
-                self.prompt_template = None
-
-        # Then try finding the template from the tokenizer_config.json
-        try:
-            self.prompt_template = get_template_from_model_json(
-                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                "chat_template",
-                "from_tokenizer_config",
-            )
-        except FileNotFoundError:
-            self.prompt_template = None
-
-        # If that fails, attempt fetching from model name
-        try:
-            template_match = find_template_from_model(model_directory)
-            self.prompt_template = get_template_from_file(template_match)
-        except (LookupError, FileNotFoundError):
-            self.prompt_template = None
+        # Try to set prompt template
+        self.prompt_template = self.find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )

         # Catch all for template lookup errors
         if self.prompt_template:
@@ -250,6 +228,34 @@ class ExllamaV2Container:
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

+    def find_prompt_template(self, prompt_template_name, model_directory):
+        """Tries to find a prompt template using various methods"""
+
+        logger.info("Loading prompt template with name " f"{prompt_template_name}")
+
+        find_template_functions = [
+            lambda: get_template_from_model_json(
+                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                "chat_template",
+                "from_tokenizer_config",
+            ),
+            lambda: get_template_from_file(find_template_from_model(model_directory)),
+        ]
+
+        # Add lookup from prompt template name if provided
+        if prompt_template_name:
+            find_template_functions.insert(
+                0, lambda: get_template_from_file(prompt_template_name)
+            )
+
+        for func in find_template_functions:
+            try:
+                prompt_template = func()
+                if prompt_template is not None:
+                    return prompt_template
+            except (FileNotFoundError, LookupError):
+                continue
+
     def calculate_rope_alpha(self, base_seq_len):
         """Calculate the rope alpha value for a given sequence length."""
         ratio = self.config.max_seq_len / base_seq_len
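The lookup precedence in find_prompt_template is: the user-supplied template file (when a name is given, it is inserted at index 0), then the chat_template field of tokenizer_config.json, then a match on the model name. Because each lookup is wrapped in a lambda, it only executes if everything before it failed. A hedged usage sketch; the constructor call and the template name "alpaca" are assumptions, not taken from this file:

    # Hypothetical call site; "alpaca" is an illustrative template name.
    container = ExllamaV2Container(model_directory, prompt_template="alpaca")
    if container.prompt_template is None:
        # All three lookups failed; the "catch all" branch in __init__
        # (the `if self.prompt_template:` check above) handles this case.
        logger.warning("No prompt template found")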