Model: Add support for HuggingFace config and bad_words_ids

This is necessary for Kobold's API. Current models define bad_words_ids
in generation_config.json, but for some reason the same IDs are also
present in the model's config.json.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-26 18:23:22 -04:00
parent 545e26608f
commit 7522b1447b
4 changed files with 68 additions and 9 deletions

View File

@@ -47,7 +47,7 @@ from common.templating import (
TemplateLoadError,
find_template_from_model,
)
from common.transformers_utils import GenerationConfig
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.utils import coalesce, unwrap
@@ -72,6 +72,7 @@ class ExllamaV2Container:
draft_cache_mode: str = "FP16"
max_batch_size: int = 20
generation_config: Optional[GenerationConfig] = None
hf_config: Optional[HuggingFaceConfig] = None
# GPU split vars
gpu_split: Optional[list] = None
@@ -186,6 +187,9 @@ class ExllamaV2Container:
except AttributeError:
pass
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
@@ -268,15 +272,8 @@ class ExllamaV2Container:
else:
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Load generation config overrides
generation_config_path = (
pathlib.Path(self.config.model_dir) / "generation_config.json"
)
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
@@ -288,6 +285,11 @@ class ExllamaV2Container:
"Skipping generation config load because of an unexpected error."
)
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Catch all for template lookup errors
if self.prompt_template:
logger.info(