Model: Fix prompt template initialization

The previous commit walked through multiple try/except blocks, which
forced the user to provide a dummy prompt template. Template loading
is now fallback based.

Run through a list of lookup functions and return as soon as one of
them succeeds.
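
A minimal sketch of that pattern, with illustrative names rather than
the actual helpers from this repo:

    from typing import Callable, Optional

    def first_successful(loaders: list[Callable[[], Optional[str]]]) -> Optional[str]:
        # Try each loader in priority order, skipping any that raise
        # or come back empty, and stop at the first real result.
        for load in loaders:
            try:
                result = load()
                if result is not None:
                    return result
            except (FileNotFoundError, LookupError):
                continue
        return None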

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-01-24 23:36:35 -05:00
Committed by: Brian Dashore
Parent: 740b0215dd
Commit: 90fb41a77a

@@ -158,32 +158,10 @@ class ExllamaV2Container:
        self.config.set_low_mem()
        """

        # Set prompt template override if provided
        prompt_template_name = kwargs.get("prompt_template")
        if prompt_template_name:
            logger.info("Loading prompt template with name " f"{prompt_template_name}")

            # Read the template
            try:
                self.prompt_template = get_template_from_file(prompt_template_name)
            except FileNotFoundError:
                self.prompt_template = None

            # Then try finding the template from the tokenizer_config.json
            try:
                self.prompt_template = get_template_from_model_json(
                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                    "chat_template",
                    "from_tokenizer_config",
                )
            except FileNotFoundError:
                self.prompt_template = None

            # If that fails, attempt fetching from model name
            try:
                template_match = find_template_from_model(model_directory)
                self.prompt_template = get_template_from_file(template_match)
            except (LookupError, FileNotFoundError):
                self.prompt_template = None

        # Try to set prompt template
        self.prompt_template = self.find_prompt_template(
            kwargs.get("prompt_template"), model_directory
        )

        # Catch all for template lookup errors
        if self.prompt_template:
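
The three inline fallbacks collapse into that single call; a
hypothetical invocation of the new helper (the "container" instance
and the template name are stand-ins, not code from this commit):

    # Hypothetical usage; "container" and "chatml" are assumptions.
    template = container.find_prompt_template("chatml", model_directory)
    if template is None:
        logger.warning("No prompt template found, chat completions may not work")

Returning None when every lookup fails leaves the error handling to
the call site, which the "if self.prompt_template:" check above takes
care of.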
@@ -250,6 +228,34 @@ class ExllamaV2Container:
        self.draft_config.max_input_len = kwargs["chunk_size"]
        self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

    def find_prompt_template(self, prompt_template_name, model_directory):
        """Tries to find a prompt template using various methods"""

        logger.info("Loading prompt template with name " f"{prompt_template_name}")

        find_template_functions = [
            lambda: get_template_from_model_json(
                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                "chat_template",
                "from_tokenizer_config",
            ),
            lambda: get_template_from_file(find_template_from_model(model_directory)),
        ]

        # Add lookup from prompt template name if provided
        if prompt_template_name:
            find_template_functions.insert(
                0, lambda: get_template_from_file(prompt_template_name)
            )

        for func in find_template_functions:
            try:
                prompt_template = func()
                if prompt_template is not None:
                    return prompt_template
            except (FileNotFoundError, LookupError):
                continue

    def calculate_rope_alpha(self, base_seq_len):
        """Calculate the rope alpha value for a given sequence length."""

        ratio = self.config.max_seq_len / base_seq_len
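
For intuition, a worked sketch of the ratio-to-alpha calculation this
method begins; the quadratic coefficients are an assumption borrowed
from the NTK-alpha fit commonly used around exllama, not verified
against this file:

    # Sketch only: the coefficients below are assumed, not taken
    # from this commit.
    def calculate_rope_alpha(max_seq_len: int, base_seq_len: int) -> float:
        ratio = max_seq_len / base_seq_len
        # No scaling is needed at or below the native context length
        if ratio <= 1.0:
            return 1.0
        return -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2

    # Example: stretching a 4096-token model to 8192 (ratio = 2.0)
    # gives alpha = -0.13436 + 1.61082 + 1.15332 ≈ 2.63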