Model: Add support for HuggingFace config and bad_words_ids

This is necessary for Kobold's API. Current models use bad_words_ids
in generation_config.json, but for some reason, they're also present
in the model's config.json.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-26 18:23:22 -04:00
parent 545e26608f
commit 7522b1447b
4 changed files with 68 additions and 9 deletions

View File

@@ -47,7 +47,7 @@ from common.templating import (
TemplateLoadError, TemplateLoadError,
find_template_from_model, find_template_from_model,
) )
from common.transformers_utils import GenerationConfig from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.utils import coalesce, unwrap from common.utils import coalesce, unwrap
@@ -72,6 +72,7 @@ class ExllamaV2Container:
draft_cache_mode: str = "FP16" draft_cache_mode: str = "FP16"
max_batch_size: int = 20 max_batch_size: int = 20
generation_config: Optional[GenerationConfig] = None generation_config: Optional[GenerationConfig] = None
hf_config: Optional[HuggingFaceConfig] = None
# GPU split vars # GPU split vars
gpu_split: Optional[list] = None gpu_split: Optional[list] = None
@@ -186,6 +187,9 @@ class ExllamaV2Container:
except AttributeError: except AttributeError:
pass pass
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Then override the base_seq_len if present # Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len") override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len: if override_base_seq_len:
@@ -268,15 +272,8 @@ class ExllamaV2Container:
else: else:
self.cache_size = self.config.max_seq_len self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Load generation config overrides # Load generation config overrides
generation_config_path = ( generation_config_path = model_directory / "generation_config.json"
pathlib.Path(self.config.model_dir) / "generation_config.json"
)
if generation_config_path.exists(): if generation_config_path.exists():
try: try:
self.generation_config = GenerationConfig.from_file( self.generation_config = GenerationConfig.from_file(
@@ -288,6 +285,11 @@ class ExllamaV2Container:
"Skipping generation config load because of an unexpected error." "Skipping generation config load because of an unexpected error."
) )
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Catch all for template lookup errors # Catch all for template lookup errors
if self.prompt_template: if self.prompt_template:
logger.info( logger.info(

View File

@@ -1,6 +1,7 @@
import json import json
import pathlib import pathlib
from typing import List, Optional, Union from typing import List, Optional, Union
from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,6 +12,7 @@ class GenerationConfig(BaseModel):
""" """
eos_token_id: Optional[Union[int, List[int]]] = None eos_token_id: Optional[Union[int, List[int]]] = None
bad_words_ids: Optional[List[List[int]]] = None
@classmethod @classmethod
def from_file(self, model_directory: pathlib.Path): def from_file(self, model_directory: pathlib.Path):
@@ -30,3 +32,40 @@ class GenerationConfig(BaseModel):
return [self.eos_token_id] return [self.eos_token_id]
else: else:
return self.eos_token_id return self.eos_token_id
class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    # Kobold-style models store bad-word token IDs in config.json as a
    # JSON-encoded string (e.g. '[[1], [2, 3]]') rather than a native list,
    # so this field is a string that gets parsed on demand.
    badwordsids: Optional[str] = None

    @classmethod
    def from_file(cls, model_directory: pathlib.Path):
        """Create an instance from a model's config.json file.

        Raises FileNotFoundError if config.json does not exist in
        model_directory, and pydantic's ValidationError if the file's
        contents don't match the expected schema.
        """

        hf_config_path = model_directory / "config.json"

        with open(
            hf_config_path, "r", encoding="utf8"
        ) as hf_config_json:
            hf_config_dict = json.load(hf_config_json)
            return cls.model_validate(hf_config_dict)

    def get_badwordsids(self):
        """Parse and return badwordsids as a list, or [] if absent/invalid."""

        # Treat a missing or empty field as "no bad words"
        if not self.badwordsids:
            return []

        try:
            return json.loads(self.badwordsids)
        except json.JSONDecodeError:
            # Best-effort: a malformed value is skipped, not fatal
            logger.warning(
                "Skipping badwordsids from config.json "
                "since it's not a valid array."
            )
            return []

View File

@@ -18,3 +18,9 @@ def prune_dict(input_dict):
"""Trim out instances of None from a dictionary.""" """Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None} return {k: v for k, v in input_dict.items() if v is not None}
def flat_map(input_list):
    """Flattens a list of lists into a single list."""

    flattened = []
    for inner in input_list:
        flattened.extend(inner)
    return flattened

View File

@@ -1,7 +1,9 @@
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest from common.sampling import BaseSamplerRequest
from common.utils import flat_map, unwrap
class GenerateRequest(BaseSamplerRequest): class GenerateRequest(BaseSamplerRequest):
@@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest):
if self.penalty_range == 0: if self.penalty_range == 0:
self.penalty_range = -1 self.penalty_range = -1
# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
bad_words_ids = unwrap(
model.container.generation_config.bad_words_ids,
model.container.hf_config.get_badwordsids()
)
if bad_words_ids:
self.banned_tokens += flat_map(bad_words_ids)
return super().to_gen_params(**kwargs) return super().to_gen_params(**kwargs)