mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-21 23:09:13 +00:00
Model: Create universal HFModel class
The HFModel class serves to coalesce all config files that contain random keys which are required for model usage. Adding this base class allows us to expand as HuggingFace randomly changes their JSON schemas over time, reducing the brunt that backend devs need to feel when their next model isn't supported. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,6 @@ import asyncio
|
||||
import gc
|
||||
import math
|
||||
import pathlib
|
||||
import traceback
|
||||
import torch
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
@@ -47,7 +46,7 @@ from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig, TokenizerConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from common.utils import calculate_rope_alpha, coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
@@ -58,6 +57,10 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
# Model directories
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
draft_model_dir: pathlib.Path = pathlib.Path("models")
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
|
||||
# HF model instance
|
||||
hf_model: HFModel
|
||||
|
||||
# Exl2 vars
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
@@ -79,8 +82,6 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
cache_mode: str = "FP16"
|
||||
draft_cache_mode: str = "FP16"
|
||||
max_batch_size: Optional[int] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
tokenizer_config: Optional[TokenizerConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
gpu_split: List[float] = []
|
||||
@@ -100,7 +101,7 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
|
||||
"""
|
||||
Primary asynchronous initializer for model container.
|
||||
|
||||
@@ -114,6 +115,7 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
self.config = ExLlamaV2Config()
|
||||
self.model_dir = model_directory
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
self.hf_model = hf_model
|
||||
|
||||
# Make the max seq len 4096 before preparing the config
|
||||
# This is a better default than 2048
|
||||
@@ -124,30 +126,6 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
# Check if the model arch is compatible with various exl2 features
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = await GenerationConfig.from_file(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Load tokenizer config overrides
|
||||
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
||||
if tokenizer_config_path.exists():
|
||||
try:
|
||||
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping tokenizer config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Set vision state and error if vision isn't supported on the current model
|
||||
self.use_vision = unwrap(kwargs.get("vision"), False)
|
||||
if self.use_vision and not self.config.vision_model_type:
|
||||
@@ -864,7 +842,7 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(
|
||||
kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
|
||||
kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
|
||||
),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
embeddings=mm_embeddings_content,
|
||||
@@ -1282,16 +1260,10 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
ban_eos_token = params.ban_eos_token
|
||||
|
||||
# Set add_bos_token for generation
|
||||
add_bos_token = unwrap(
|
||||
params.add_bos_token, self.tokenizer_config.add_bos_token
|
||||
)
|
||||
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
# Fetch EOS tokens from the HF model if they exist
|
||||
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
|
||||
Reference in New Issue
Block a user