From 7522b1447b1e2d569fd9909ac91112d2308f40dd Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 18:23:22 -0400 Subject: [PATCH] Model: Add support for HuggingFace config and bad_words_ids This is necessary for Kobold's API. Current models use bad_words_ids in generation_config.json, but for some reason, they're also present in the model's config.json. Signed-off-by: kingbri --- backends/exllamav2/model.py | 20 +++++++------- common/transformers_utils.py | 39 ++++++++++++++++++++++++++++ common/utils.py | 6 +++++ endpoints/Kobold/types/generation.py | 12 +++++++++ 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3515123..3df16b0 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -47,7 +47,7 @@ from common.templating import ( TemplateLoadError, find_template_from_model, ) -from common.transformers_utils import GenerationConfig +from common.transformers_utils import GenerationConfig, HuggingFaceConfig from common.utils import coalesce, unwrap @@ -72,6 +72,7 @@ class ExllamaV2Container: draft_cache_mode: str = "FP16" max_batch_size: int = 20 generation_config: Optional[GenerationConfig] = None + hf_config: Optional[HuggingFaceConfig] = None # GPU split vars gpu_split: Optional[list] = None @@ -186,6 +187,9 @@ class ExllamaV2Container: except AttributeError: pass + # Create the hf_config + self.hf_config = HuggingFaceConfig.from_file(model_directory) + # Then override the base_seq_len if present override_base_seq_len = kwargs.get("override_base_seq_len") if override_base_seq_len: @@ -268,15 +272,8 @@ class ExllamaV2Container: else: self.cache_size = self.config.max_seq_len - # Try to set prompt template - self.prompt_template = self.find_prompt_template( - kwargs.get("prompt_template"), model_directory - ) - # Load generation config overrides - generation_config_path = ( - pathlib.Path(self.config.model_dir) / "generation_config.json" - ) + 
generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: self.generation_config = GenerationConfig.from_file( model_directory ) except Exception: logger.error( "Skipping generation config load because of an unexpected error." ) + # Try to set prompt template + self.prompt_template = self.find_prompt_template( + kwargs.get("prompt_template"), model_directory + ) + # Catch all for template lookup errors if self.prompt_template: logger.info( diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 62d4622..2431b1b 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,6 +1,7 @@ import json import pathlib from typing import List, Optional, Union +from loguru import logger from pydantic import BaseModel @@ -11,6 +12,7 @@ class GenerationConfig(BaseModel): """ eos_token_id: Optional[Union[int, List[int]]] = None + bad_words_ids: Optional[List[List[int]]] = None @classmethod def from_file(self, model_directory: pathlib.Path): @@ -30,3 +32,40 @@ class GenerationConfig(BaseModel): return [self.eos_token_id] else: return self.eos_token_id + + +class HuggingFaceConfig(BaseModel): + """ + An abridged version of HuggingFace's model config. + Will be expanded as needed. + """ + + badwordsids: Optional[str] = None + + @classmethod + def from_file(self, model_directory: pathlib.Path): + """Create an instance from a HuggingFace config file (config.json).""" + + hf_config_path = model_directory / "config.json" + with open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + hf_config_dict = json.load(hf_config_json) + return self.model_validate(hf_config_dict) + + def get_badwordsids(self): + """Wrapper method to fetch badwordsids.""" + + if self.badwordsids: + try: + bad_words_list = json.loads(self.badwordsids) + return bad_words_list + except json.JSONDecodeError: + logger.warning( + "Skipping badwordsids from config.json " + "since it's not a valid array." 
+ ) + + return [] + else: + return [] diff --git a/common/utils.py b/common/utils.py index 6787f39..b120022 100644 --- a/common/utils.py +++ b/common/utils.py @@ -18,3 +18,9 @@ def prune_dict(input_dict): """Trim out instances of None from a dictionary.""" return {k: v for k, v in input_dict.items() if v is not None} + + +def flat_map(input_list): + """Flattens a list of lists into a single list.""" + + return [item for sublist in input_list for item in sublist] diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index eab214c..0ee5489 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,7 +1,9 @@ from typing import List, Optional from pydantic import BaseModel, Field +from common import model from common.sampling import BaseSamplerRequest +from common.utils import flat_map, unwrap class GenerateRequest(BaseSamplerRequest): @@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest): if self.penalty_range == 0: self.penalty_range = -1 + # Move badwordsids into banned tokens for generation + if self.use_default_badwordsids: + bad_words_ids = unwrap( + model.container.generation_config.bad_words_ids, + model.container.hf_config.get_badwordsids() + ) + + if bad_words_ids: + self.banned_tokens += flat_map(bad_words_ids) + return super().to_gen_params(**kwargs)