Model: Add TokenizerConfig stub and add_bos_token fallback
This stub fetches the add_bos_token field from the HF tokenizer config. Ideally, this should live in the backend rather than in tabby.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
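For context, the fallback boils down to reading one boolean out of the model's tokenizer_config.json. A minimal standalone sketch of that idea (the sample file contents below are hypothetical, not from any particular model):

import json

# Hypothetical tokenizer_config.json excerpt; real HF tokenizer configs
# carry many more fields, but add_bos_token is the one this stub reads.
sample = '{"add_bos_token": false, "model_max_length": 32768}'

# dict.get yields None when the key is absent, leaving the decision
# to a downstream default instead of hard-coding one here.
add_bos_token = json.loads(sample).get("add_bos_token")
print(add_bos_token)  # -> False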
@@ -50,7 +50,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import GenerationConfig, TokenizerConfig
 from common.utils import calculate_rope_alpha, coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
 
@@ -80,6 +80,7 @@ class ExllamaV2Container(BaseModelContainer):
     draft_cache_mode: str = "FP16"
     max_batch_size: Optional[int] = None
     generation_config: Optional[GenerationConfig] = None
+    tokenizer_config: Optional[TokenizerConfig] = None
 
     # GPU split vars
     gpu_split: List[float] = []
@@ -130,7 +131,7 @@ class ExllamaV2Container(BaseModelContainer):
         if generation_config_path.exists():
             try:
                 self.generation_config = await GenerationConfig.from_file(
-                    generation_config_path.parent
+                    model_directory
                 )
             except Exception:
                 logger.error(traceback.format_exc())
@@ -138,6 +139,19 @@ class ExllamaV2Container(BaseModelContainer):
                     "Skipping generation config load because of an unexpected error."
                 )
 
+        # Load tokenizer config overrides
+        tokenizer_config_path = model_directory / "tokenizer_config.json"
+        if tokenizer_config_path.exists():
+            try:
+                self.tokenizer_config = await TokenizerConfig.from_file(
+                    model_directory
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping tokenizer config load because of an unexpected error."
+                )
+
         # Set vision state and error if vision isn't supported on the current model
         self.use_vision = unwrap(kwargs.get("vision"), False)
         if self.use_vision and not self.config.vision_model_type:
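The tokenizer config load mirrors the generation config load above it: the file is optional, and a parse or validation failure is logged and skipped rather than aborting the model load. A rough sketch of that tolerate-and-continue pattern, using only the standard library (load_optional_config is a hypothetical helper, not tabby's API):

import json
import logging
import pathlib
import traceback

logger = logging.getLogger(__name__)

def load_optional_config(model_directory: pathlib.Path, name: str):
    """Return the parsed config dict, or None if missing or unreadable."""
    config_path = model_directory / name
    if not config_path.exists():
        return None
    try:
        return json.loads(config_path.read_text(encoding="utf8"))
    except Exception:
        logger.error(traceback.format_exc())
        logger.warning(f"Skipping {name} load because of an unexpected error.")
        return None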
@@ -1240,9 +1254,17 @@ class ExllamaV2Container(BaseModelContainer):
         ) and gen_settings.token_repetition_range == -1
 
         stop_conditions = params.stop
-        add_bos_token = unwrap(params.add_bos_token, True)
         ban_eos_token = params.ban_eos_token
 
+
+        print(self.tokenizer_config.add_bos_token)
+        # Set add_bos_token for generation
+        add_bos_token = coalesce(
+            params.add_bos_token, self.tokenizer_config.add_bos_token, True
+        )
+
+        print(add_bos_token)
+
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
             self.generation_config.eos_tokens()
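Swapping unwrap for coalesce is what creates the fallback chain: an explicit request value wins, then the tokenizer config's add_bos_token, then True. A self-contained sketch of that precedence (assuming coalesce returns its first non-None argument, which is how the call above reads; this is not tabby's implementation):

from typing import Optional

def coalesce(*args):
    """Return the first argument that is not None."""
    return next((arg for arg in args if arg is not None), None)

request_value: Optional[bool] = None   # client left add_bos_token unset
config_value: Optional[bool] = False   # tokenizer_config.json disables BOS

print(coalesce(request_value, config_value, True))  # -> False (config wins)
print(coalesce(True, config_value, True))           # -> True (request wins)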
@@ -239,6 +239,7 @@ async def find_prompt_template(template_name, model_dir: pathlib.Path):
     ]
 
     # Add lookup from prompt template name if provided
+    # TODO: Possibly link to the TokenizerConfig class
     if template_name:
         find_template_functions[:0] = [
             lambda: PromptTemplate.from_file(pathlib.Path("templates") / template_name),
@@ -53,3 +53,23 @@ class HuggingFaceConfig(BaseModel):
         contents = await hf_config_json.read()
         hf_config_dict = json.loads(contents)
         return cls.model_validate(hf_config_dict)
+
+
+class TokenizerConfig(BaseModel):
+    """
+    An abridged version of HuggingFace's tokenizer config.
+    """
+
+    add_bos_token: Optional[bool] = None
+
+    @classmethod
+    async def from_file(cls, model_directory: pathlib.Path):
+        """Create an instance from a tokenizer config file."""
+
+        tokenizer_config_path = model_directory / "tokenizer_config.json"
+        async with aiofiles.open(
+            tokenizer_config_path, "r", encoding="utf8"
+        ) as tokenizer_config_json:
+            contents = await tokenizer_config_json.read()
+            tokenizer_config_dict = json.loads(contents)
+            return cls.model_validate(tokenizer_config_dict)
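One detail that makes the abridged model work: pydantic's BaseModel ignores unknown fields by default, so model_validate accepts a complete HF tokenizer_config.json while keeping only add_bos_token. A hedged usage sketch (synchronous, and with made-up config contents, to stay self-contained):

from typing import Optional
from pydantic import BaseModel

class TokenizerConfig(BaseModel):
    # The only field tabby cares about; unknown keys are ignored.
    add_bos_token: Optional[bool] = None

# A real tokenizer_config.json carries many more keys than this.
config_dict = {
    "add_bos_token": True,
    "add_eos_token": False,
    "model_max_length": 32768,
}

config = TokenizerConfig.model_validate(config_dict)
print(config.add_bos_token)  # -> True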