Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-23 15:59:14 +00:00)
Model: Create universal HFModel class
The HFModel class coalesces all of the config files that hold the assorted keys required for model usage. Adding this base class lets us extend support as HuggingFace changes its JSON schemas over time, reducing the burden on backend devs when their next model isn't supported.
Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
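The commit page does not show common/transformers_utils itself, but the diff below depends on an HFModel object that exposes add_bos_token() and eos_tokens() and is passed into the backend's create(). As a rough illustration only (the real class may differ in fields, validation, and how the files are loaded), a wrapper with that shape could look like this:

import json
import pathlib
from typing import List


class HFModel:
    """Illustrative sketch only: one object that coalesces the assorted
    HuggingFace JSON configs (config.json, generation_config.json,
    tokenizer_config.json). The real class in common/transformers_utils
    may differ in fields, validation, and loading strategy."""

    def __init__(self, model_directory: pathlib.Path):
        self.config = self._read_json(model_directory / "config.json")
        self.generation_config = self._read_json(
            model_directory / "generation_config.json"
        )
        self.tokenizer_config = self._read_json(
            model_directory / "tokenizer_config.json"
        )

    @staticmethod
    def _read_json(path: pathlib.Path) -> dict:
        # Missing or malformed files fall back to an empty dict so the
        # accessors below degrade gracefully instead of raising.
        try:
            return json.loads(path.read_text())
        except (OSError, json.JSONDecodeError):
            return {}

    def add_bos_token(self) -> bool:
        # tokenizer_config.json is the usual home for this flag; default True.
        return bool(self.tokenizer_config.get("add_bos_token", True))

    def eos_tokens(self) -> List[int]:
        # generation_config.json may store eos_token_id as an int or a list;
        # fall back to config.json, then to an empty list.
        eos = self.generation_config.get(
            "eos_token_id", self.config.get("eos_token_id")
        )
        if eos is None:
            return []
        return eos if isinstance(eos, list) else [eos]

Returning an empty list when no EOS ids are found pairs with the fallback idiom used in the diff, eos_tokens() or [self.tokenizer.eos_token_id].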
@@ -2,7 +2,6 @@ import asyncio
 import gc
 import pathlib
 import re
 import traceback
 from typing import (
     Any,
     AsyncIterator,
@@ -35,7 +34,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig, TokenizerConfig
+from common.transformers_utils import HFModel
 from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
 
@@ -46,7 +45,9 @@ class ExllamaV3Container(BaseModelContainer):
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
     prompt_template: Optional[PromptTemplate] = None
-    generation_config: Optional[GenerationConfig] = None
 
+    # HF Model instance
+    hf_model: HFModel
+
     # Load synchronization
     # The bool is a master switch for accepting requests
@@ -58,15 +59,14 @@ class ExllamaV3Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()
 
     # Exl3 vars
-    model: Optional[Model]
-    cache: Optional[Cache]
-    draft_model: Optional[Model]
-    draft_cache: Optional[Cache]
-    tokenizer: Optional[Tokenizer]
-    config: Optional[Config]
-    draft_config: Optional[Config]
-    generator: Optional[AsyncGenerator]
-    tokenizer_config: Optional[TokenizerConfig]
+    model: Optional[Model] = None
+    cache: Optional[Cache] = None
+    draft_model: Optional[Model] = None
+    draft_cache: Optional[Cache] = None
+    tokenizer: Optional[Tokenizer] = None
+    config: Optional[Config] = None
+    draft_config: Optional[Config] = None
+    generator: Optional[AsyncGenerator] = None
 
     # Class-specific vars
     gpu_split: List[float] | None = None
@@ -82,7 +82,7 @@ class ExllamaV3Container(BaseModelContainer):
 
     # Required methods
     @classmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Asynchronously creates and initializes a model container instance.
 
@@ -96,50 +96,17 @@ class ExllamaV3Container(BaseModelContainer):
 
         self = cls()
 
-        self.model = None
-        self.cache = None
-        self.draft_model = None
-        self.draft_cache = None
-        self.tokenizer = None
-        self.config = None
-        self.draft_config = None
-        self.generator = None
-        self.tokenizer_config = None
-
         logger.warning(
             "ExllamaV3 is currently in an alpha state. "
             "Please note that all config options may not work."
         )
 
         self.model_dir = model_directory
+        self.hf_model = hf_model
         self.config = Config.from_directory(str(model_directory.resolve()))
         self.model = Model.from_config(self.config)
         self.tokenizer = Tokenizer.from_config(self.config)
 
-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    model_directory
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Load tokenizer config overrides
-        tokenizer_config_path = model_directory / "tokenizer_config.json"
-        if tokenizer_config_path.exists():
-            try:
-                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping tokenizer config load because of an unexpected error."
-                )
-
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
 
@@ -554,7 +521,9 @@ class ExllamaV3Container(BaseModelContainer):
         return (
             self.tokenizer.encode(
                 text,
-                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                add_bos=unwrap(
+                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
+                ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
             )
             .flatten()
@@ -822,16 +791,10 @@ class ExllamaV3Container(BaseModelContainer):
 
         prompts = [prompt]
         stop_conditions = params.stop
-        add_bos_token = unwrap(
-            params.add_bos_token, self.tokenizer_config.add_bos_token
-        )
+        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
 
         # Fetch EOS tokens from generation_config if they exist
-        eos_tokens = (
-            self.generation_config.eos_tokens()
-            if self.generation_config
-            else [self.tokenizer.eos_token_id]
-        )
+        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
 
         stop_conditions += eos_tokens
 
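With the signature change above, callers are expected to build the HFModel once and hand it to the container. A hypothetical caller is sketched below; the module path, HFModel constructor, and load flow are assumptions (tabbyAPI's real load path in common/model.py may construct HFModel differently):

import asyncio
import pathlib

from backends.exllamav3.model import ExllamaV3Container  # module path assumed
from common.transformers_utils import HFModel


async def load_container() -> ExllamaV3Container:
    # Hypothetical caller sketch; the real loader may build HFModel
    # asynchronously or via a dedicated classmethod.
    model_dir = pathlib.Path("models/my-model")
    hf_model = HFModel(model_dir)  # constructor assumed, see sketch above
    return await ExllamaV3Container.create(
        model_dir,
        hf_model,
        max_seq_len=8192,  # optional; the backend falls back to 4096
    )


if __name__ == "__main__":
    container = asyncio.run(load_container())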