Files
tabbyAPI/common/transformers_utils.py
2026-03-18 03:29:15 +01:00

193 lines
5.7 KiB
Python

import aiofiles
import json
import pathlib
from loguru import logger
from pydantic import BaseModel
from typing import Dict, List, Optional, Set, Union
class GenerationConfig(BaseModel):
"""
An abridged version of HuggingFace's GenerationConfig.
Will be expanded as needed.
"""
eos_token_id: Optional[Union[int, List[int]]] = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
generation_config_path = model_directory / "generation_config.json"
async with aiofiles.open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
contents = await generation_config_json.read()
generation_config_dict = json.loads(contents)
return cls.model_validate(generation_config_dict)
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
if isinstance(self.eos_token_id, list):
return self.eos_token_id
elif isinstance(self.eos_token_id, int):
return [self.eos_token_id]
else:
return []
class TextConfig(BaseModel):
max_position_embeddings: int = None
class HuggingFaceConfig(BaseModel):
"""
An abridged version of HuggingFace's model config.
Will be expanded as needed.
"""
max_position_embeddings: int = None
text_config: Optional[TextConfig] = None
eos_token_id: Optional[Union[int, List[int]]] = None
quantization_config: Optional[Dict] = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
async with aiofiles.open(
hf_config_path, "r", encoding="utf8"
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return cls.model_validate(hf_config_dict)
def quant_method(self):
"""Wrapper method to fetch quant type"""
if isinstance(self.quantization_config, Dict):
return self.quantization_config.get("quant_method")
else:
return None
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
if isinstance(self.eos_token_id, list):
return self.eos_token_id
elif isinstance(self.eos_token_id, int):
return [self.eos_token_id]
else:
return []
def get_max_position_embeddings(self, default: int | None = 4096) -> int:
if (
self.text_config is not None
and self.text_config.max_position_embeddings is not None
):
return self.text_config.max_position_embeddings
elif self.max_position_embeddings is not None:
return self.max_position_embeddings
else:
return default
class TokenizerConfig(BaseModel):
"""
An abridged version of HuggingFace's tokenizer config.
"""
add_bos_token: Optional[bool] = True
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a tokenizer config file."""
tokenizer_config_path = model_directory / "tokenizer_config.json"
async with aiofiles.open(
tokenizer_config_path, "r", encoding="utf8"
) as tokenizer_config_json:
contents = await tokenizer_config_json.read()
tokenizer_config_dict = json.loads(contents)
return cls.model_validate(tokenizer_config_dict)
class HFModel:
"""
Unified container for HuggingFace model configuration files.
These are abridged for hyper-specific model parameters not covered
by most backends.
Includes:
- config.json
- generation_config.json
- tokenizer_config.json
"""
hf_config: HuggingFaceConfig
tokenizer_config: Optional[TokenizerConfig] = None
generation_config: Optional[GenerationConfig] = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a model directory"""
self = cls()
# A model must have an HF config
try:
self.hf_config = await HuggingFaceConfig.from_directory(model_directory)
except Exception as exc:
raise ValueError(
f"Failed to load config.json from {model_directory}"
) from exc
try:
self.generation_config = await GenerationConfig.from_directory(
model_directory
)
except Exception:
logger.warning(
"Generation config file not found in model directory, skipping."
)
try:
self.tokenizer_config = await TokenizerConfig.from_directory(
model_directory
)
except Exception:
logger.warning(
"Tokenizer config file not found in model directory, skipping."
)
return self
def quant_method(self):
"""Wrapper for quantization method"""
return self.hf_config.quant_method()
def eos_tokens(self):
"""Combines and returns EOS tokens from various configs"""
eos_ids: Set[int] = set()
eos_ids.update(self.hf_config.eos_tokens())
if self.generation_config:
eos_ids.update(self.generation_config.eos_tokens())
# Convert back to a list
return list(eos_ids)
def add_bos_token(self):
"""Wrapper for tokenizer config"""
if self.tokenizer_config:
return self.tokenizer_config.add_bos_token
# Expected default
return True