mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 06:19:15 +00:00
Model: Add EOS token support from generation_config.json
GenerationConfig is meant to override various parts of the model on generation within the transformers lib. Rather than implementing the entire GenerationConfig framework (since it's pretty redundant), add in multi eos_token support like VLLM. The GenerationConfig is used only for generation, but can be used for other uses if needed. If there's more necessary parameters in the future, add those in as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -36,6 +36,7 @@ from common.templating import (
|
||||
get_template_from_model_json,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
||||
|
||||
@@ -57,6 +58,7 @@ class ExllamaV2Container:
|
||||
# Internal config vars
|
||||
cache_mode: str = "FP16"
|
||||
use_cfg: bool = False
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
gpu_split: Optional[list] = None
|
||||
@@ -193,6 +195,21 @@ class ExllamaV2Container:
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = (
|
||||
pathlib.Path(self.config.model_dir) / "generation_config.json"
|
||||
)
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Catch all for template lookup errors
|
||||
if self.prompt_template:
|
||||
logger.info(
|
||||
@@ -566,6 +583,7 @@ class ExllamaV2Container:
|
||||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
# TODO: Maybe support generation_config for eos_token
|
||||
def get_special_tokens(
|
||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||
):
|
||||
@@ -840,13 +858,20 @@ class ExllamaV2Container:
|
||||
grammar_string, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
if ban_eos_token:
|
||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||
gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
|
||||
else:
|
||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||
stop_conditions += eos_tokens
|
||||
|
||||
# Stop conditions
|
||||
self.generator.set_stop_conditions(stop_conditions)
|
||||
@@ -891,6 +916,8 @@ class ExllamaV2Container:
|
||||
token_healing=token_healing,
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
generate_window=generate_window,
|
||||
bos_token_id=self.tokenizer.bos_token_id,
|
||||
eos_token_id=eos_tokens,
|
||||
add_bos_token=add_bos_token,
|
||||
ban_eos_token=ban_eos_token,
|
||||
speculative_ngram=self.generator.speculative_ngram,
|
||||
|
||||
Reference in New Issue
Block a user