import asyncio
import gc
import pathlib
from loguru import logger
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

import torch

from backends.base_model_container import BaseModelContainer
from common.concurrency import iterate_in_threadpool
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from common.utils import unwrap
from endpoints.core.types.model import ModelCard

from exllamav3 import Config, Model, Cache, Tokenizer


class ExllamaV3Container(BaseModelContainer):
    """Model container for ExllamaV3 models."""

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None
    generation_config: Optional[GenerationConfig] = None

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()
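
    # For example, a caller that needs a ready model could wait on the
    # condition like this (a usage sketch only; nothing in this file calls it):
    #
    #     async with self.load_condition:
    #         await self.load_condition.wait_for(lambda: self.loaded)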

    # Exl3 vars
    model: Model
    cache: Cache
    tokenizer: Tokenizer
    config: Config

    # Required methods
    @classmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        """
        Asynchronously creates and initializes a model container instance.

        Args:
            model_directory: Path to the model files.
            **kwargs: Backend-specific configuration options.

        Returns:
            An instance of the implementing class.
        """

        self = cls()

        logger.warning(
            "ExllamaV3 is currently in an alpha state. "
            "Not all config options may work as expected."
        )

        self.config = Config.from_directory(model_directory.resolve())
        self.model = Model.from_config(self.config)
        self.tokenizer = Tokenizer.from_config(self.config)

        max_seq_len = kwargs.get("max_seq_len")
        self.cache = Cache(self.model, max_num_tokens=max_seq_len)

        return self

    async def load(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.
        """

        async for _ in self.load_gen(progress_callback):
            pass

    async def load_gen(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory, yielding progress updates.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.

        Yields:
            Progress updates
        """

        try:
            await self.load_lock.acquire()

            # Wait for existing generation jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            generator = self.load_model_sync(progress_callback)
            async for module, modules in iterate_in_threadpool(generator):
                yield module, modules

            # Clean up any extra vram usage from torch and cuda
            # (Helps reduce VRAM bottlenecking on Windows)
            gc.collect()
            torch.cuda.empty_cache()

            # Cleanup and update model load state
            self.loaded = True
            logger.info("Model successfully loaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    # TODO: Add draft loading
    @torch.inference_mode()
    def load_model_sync(self, progress_callback=None):
        for value in self.model.load_gen(callback=progress_callback):
            if value:
                yield value

    async def unload(self, loras_only: bool = False, **kwargs):
        """
        Unloads the model and associated resources from memory.

        Args:
            loras_only: If True, only unload LoRAs.
            **kwargs: Additional unloading options (e.g., shutdown).
        """

        try:
            await self.load_lock.acquire()

            # Wait for other jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            self.model.unload()
            self.model = None

            self.config = None
            self.cache = None
            self.tokenizer = None

            gc.collect()
            torch.cuda.empty_cache()

            logger.info("Model unloaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    def encode_tokens(self, text: str, **kwargs) -> List[int]:
        """
        Encodes a string of text into a list of token IDs.

        Args:
            text: The input text string.
            **kwargs: Backend-specific encoding options (e.g., add_bos_token).

        Returns:
            A list of integer token IDs.
        """

        return self.tokenizer.encode(
            text,
            add_bos=unwrap(kwargs.get("add_bos_token"), True),
            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
        ).flatten().tolist()

    def decode_tokens(self, ids: List[int], **kwargs) -> str:
        """
        Decodes a list of token IDs back into a string.

        Args:
            ids: A list of integer token IDs.
            **kwargs: Backend-specific decoding options (e.g., decode_special_tokens).

        Returns:
            The decoded text string.
        """

        ids = torch.tensor([ids])
        return self.tokenizer.decode(
            ids,
            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
        )[0]

    def get_special_tokens(
        self, add_bos_token: bool = True, ban_eos_token: bool = False
    ):
        """
        Gets special tokens used by the model/tokenizer.

        Args:
            add_bos_token: Whether to include the BOS token string.
            ban_eos_token: If True, return an empty EOS token string.

        Returns:
            A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
            to their string or ID representation.
        """

        return {
            "bos_token": self.tokenizer.bos_token if add_bos_token else "",
            "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
            "pad_token": self.tokenizer.pad_token,
            "unk_token": self.tokenizer.unk_token,
        }

    def model_info(self) -> ModelCard:
        """
        Returns a ModelCard with the current model's configuration parameters.

        Returns:
            Model parameters provided by the backend
        """

        pass
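
        # A possible sketch for when this stub is filled in, assuming the
        # ModelCard type accepts an `id` field (an assumption here):
        #
        #     return ModelCard(id=self.model_dir.name)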

    async def wait_for_jobs(self, skip_wait: bool = False):
        """
        Waits for any active generation jobs to complete.

        Args:
            skip_wait: If True, cancel jobs immediately instead of waiting.
        """

        pass
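
        # A minimal sketch of the waiting half, polling the active_job_ids
        # dict declared above (cancellation for skip_wait is omitted because
        # job objects are not tracked yet in this alpha backend):
        #
        #     if not skip_wait:
        #         while len(self.active_job_ids) > 0:
        #             await asyncio.sleep(0.01)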

    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> Dict[str, Any]:
        """
        Generates a complete response for a given prompt and parameters.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Returns:
            A dictionary containing the generation info
        """

        pass
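
        # A possible sketch once stream_generate is implemented: collect the
        # streamed chunks and join them into a single response. The "text"
        # key is an assumption about the chunk layout, not a settled interface:
        #
        #     generations = []
        #     async for chunk in self.stream_generate(
        #         request_id, prompt, params, abort_event, mm_embeddings
        #     ):
        #         generations.append(chunk)
        #
        #     return {
        #         "text": "".join(chunk.get("text", "") for chunk in generations)
        #     }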

    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Generates a response iteratively (streaming) for a given prompt.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Yields:
            Generation chunks
        """

        # Unreachable yield so Python treats this stub as an async generator
        if False:
            yield
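
        # A rough sketch of the intended flow. The actual exllamav3 generator
        # call is left as a placeholder because it is not wired up here yet;
        # only the request bookkeeping uses attributes defined in this class:
        #
        #     self.active_job_ids[request_id] = None
        #     try:
        #         # TODO: build an exllamav3 job from `prompt` and `params`,
        #         # iterate its output, and yield chunk dicts such as
        #         # {"text": chunk_text}, checking abort_event between chunks.
        #         ...
        #     finally:
        #         del self.active_job_ids[request_id]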