mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 10:42:03 +00:00
Model: Add support for HuggingFace config and bad_words_ids
This is necessary for Kobold's API. Current models specify bad_words_ids in generation_config.json, but for some reason the banned token IDs are also present in the model's config.json (read here from the `badwordsids` key). Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -47,7 +47,7 @@ from common.templating import (
|
|||||||
TemplateLoadError,
|
TemplateLoadError,
|
||||||
find_template_from_model,
|
find_template_from_model,
|
||||||
)
|
)
|
||||||
from common.transformers_utils import GenerationConfig
|
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
|
||||||
from common.utils import coalesce, unwrap
|
from common.utils import coalesce, unwrap
|
||||||
|
|
||||||
|
|
||||||
@@ -72,6 +72,7 @@ class ExllamaV2Container:
|
|||||||
draft_cache_mode: str = "FP16"
|
draft_cache_mode: str = "FP16"
|
||||||
max_batch_size: int = 20
|
max_batch_size: int = 20
|
||||||
generation_config: Optional[GenerationConfig] = None
|
generation_config: Optional[GenerationConfig] = None
|
||||||
|
hf_config: Optional[HuggingFaceConfig] = None
|
||||||
|
|
||||||
# GPU split vars
|
# GPU split vars
|
||||||
gpu_split: Optional[list] = None
|
gpu_split: Optional[list] = None
|
||||||
@@ -186,6 +187,9 @@ class ExllamaV2Container:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Create the hf_config
|
||||||
|
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
||||||
|
|
||||||
# Then override the base_seq_len if present
|
# Then override the base_seq_len if present
|
||||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||||
if override_base_seq_len:
|
if override_base_seq_len:
|
||||||
@@ -268,15 +272,8 @@ class ExllamaV2Container:
|
|||||||
else:
|
else:
|
||||||
self.cache_size = self.config.max_seq_len
|
self.cache_size = self.config.max_seq_len
|
||||||
|
|
||||||
# Try to set prompt template
|
|
||||||
self.prompt_template = self.find_prompt_template(
|
|
||||||
kwargs.get("prompt_template"), model_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load generation config overrides
|
# Load generation config overrides
|
||||||
generation_config_path = (
|
generation_config_path = model_directory / "generation_config.json"
|
||||||
pathlib.Path(self.config.model_dir) / "generation_config.json"
|
|
||||||
)
|
|
||||||
if generation_config_path.exists():
|
if generation_config_path.exists():
|
||||||
try:
|
try:
|
||||||
self.generation_config = GenerationConfig.from_file(
|
self.generation_config = GenerationConfig.from_file(
|
||||||
@@ -288,6 +285,11 @@ class ExllamaV2Container:
|
|||||||
"Skipping generation config load because of an unexpected error."
|
"Skipping generation config load because of an unexpected error."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Try to set prompt template
|
||||||
|
self.prompt_template = self.find_prompt_template(
|
||||||
|
kwargs.get("prompt_template"), model_directory
|
||||||
|
)
|
||||||
|
|
||||||
# Catch all for template lookup errors
|
# Catch all for template lookup errors
|
||||||
if self.prompt_template:
|
if self.prompt_template:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ class GenerationConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None
|
eos_token_id: Optional[Union[int, List[int]]] = None
|
||||||
|
bad_words_ids: Optional[List[List[int]]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(self, model_directory: pathlib.Path):
|
def from_file(self, model_directory: pathlib.Path):
|
||||||
@@ -30,3 +32,40 @@ class GenerationConfig(BaseModel):
|
|||||||
return [self.eos_token_id]
|
return [self.eos_token_id]
|
||||||
else:
|
else:
|
||||||
return self.eos_token_id
|
return self.eos_token_id
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    # Raw value of the "badwordsids" key from config.json. It is kept as a
    # string and decoded as JSON on demand by get_badwordsids().
    badwordsids: Optional[str] = None

    @classmethod
    def from_file(cls, model_directory: pathlib.Path):
        """
        Create an instance from the config.json file of a model directory.

        Args:
            model_directory: Directory that contains config.json.

        Returns:
            The instance produced by validating the parsed JSON document.

        Raises:
            OSError: If config.json cannot be opened.
            json.JSONDecodeError: If config.json is not valid JSON.
        """

        hf_config_path = model_directory / "config.json"
        with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
            hf_config_dict = json.load(hf_config_json)

        return cls.model_validate(hf_config_dict)

    def get_badwordsids(self):
        """
        Wrapper method to fetch badwordsids.

        Returns:
            The list decoded from the badwordsids JSON string, or an empty
            list when the field is unset/empty, is not valid JSON, or does
            not decode to a JSON array.
        """

        if not self.badwordsids:
            return []

        try:
            bad_words_list = json.loads(self.badwordsids)
        except json.JSONDecodeError:
            bad_words_list = None

        # Valid JSON that is not an array (a number, null, an object, ...)
        # is rejected as well, so callers can always iterate the result.
        if isinstance(bad_words_list, list):
            return bad_words_list

        logger.warning(
            "Skipping badwordsids from config.json "
            "since it's not a valid array."
        )

        return []
|
||||||
|
|||||||
@@ -18,3 +18,9 @@ def prune_dict(input_dict):
|
|||||||
"""Trim out instances of None from a dictionary."""
|
"""Trim out instances of None from a dictionary."""
|
||||||
|
|
||||||
return {k: v for k, v in input_dict.items() if v is not None}
|
return {k: v for k, v in input_dict.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def flat_map(input_list):
    """
    Flattens a list of lists into a single list.

    Only one level of nesting is removed. Items keep the order in which
    they appear, sublist by sublist.
    """

    return [item for sublist in input_list for item in sublist]
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from common import model
|
||||||
from common.sampling import BaseSamplerRequest
|
from common.sampling import BaseSamplerRequest
|
||||||
|
from common.utils import flat_map, unwrap
|
||||||
|
|
||||||
|
|
||||||
class GenerateRequest(BaseSamplerRequest):
|
class GenerateRequest(BaseSamplerRequest):
|
||||||
@@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest):
|
|||||||
if self.penalty_range == 0:
|
if self.penalty_range == 0:
|
||||||
self.penalty_range = -1
|
self.penalty_range = -1
|
||||||
|
|
||||||
|
# Move badwordsids into banned tokens for generation
|
||||||
|
if self.use_default_badwordsids:
|
||||||
|
bad_words_ids = unwrap(
|
||||||
|
model.container.generation_config.bad_words_ids,
|
||||||
|
model.container.hf_config.get_badwordsids()
|
||||||
|
)
|
||||||
|
|
||||||
|
if bad_words_ids:
|
||||||
|
self.banned_tokens += flat_map(bad_words_ids)
|
||||||
|
|
||||||
return super().to_gen_params(**kwargs)
|
return super().to_gen_params(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user