mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Model: Add support for HuggingFace config and bad_words_ids
This is necessary for Kobold's API. Current models use bad_words_ids in generation_config.json, but for some reason, they're also present in the model's config.json.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -47,7 +47,7 @@ from common.templating import (
|
||||
TemplateLoadError,
|
||||
find_template_from_model,
|
||||
)
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ class ExllamaV2Container:
|
||||
draft_cache_mode: str = "FP16"
|
||||
max_batch_size: int = 20
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
hf_config: Optional[HuggingFaceConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
gpu_split: Optional[list] = None
|
||||
@@ -186,6 +187,9 @@ class ExllamaV2Container:
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Create the hf_config
|
||||
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
||||
|
||||
# Then override the base_seq_len if present
|
||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||
if override_base_seq_len:
|
||||
@@ -268,15 +272,8 @@ class ExllamaV2Container:
|
||||
else:
|
||||
self.cache_size = self.config.max_seq_len
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = (
|
||||
pathlib.Path(self.config.model_dir) / "generation_config.json"
|
||||
)
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
@@ -288,6 +285,11 @@ class ExllamaV2Container:
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
)
|
||||
|
||||
# Catch all for template lookup errors
|
||||
if self.prompt_template:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user