diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index 5c79867..10eae0d 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -25,6 +25,10 @@ class BaseModelContainer(abc.ABC):
     prompt_template: Optional[PromptTemplate] = None
     generation_config: Optional[GenerationConfig] = None

+    # Optional features
+    use_draft_model: bool = False
+    use_vision: bool = False
+
     # Load synchronization
     # The bool is a master switch for accepting requests
     # The lock keeps load tasks sequential
@@ -65,7 +69,7 @@

     # NOTE: Might be an optional method
     @abc.abstractmethod
-    async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
+    async def load_gen(self, progress_callback=None, **kwargs):
         """
         Loads the model into memory, yielding progress updates.

@@ -134,57 +138,6 @@

         pass

-    @abc.abstractmethod
-    async def generate(
-        self,
-        request_id: str,
-        prompt: str,
-        params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> Dict[str, Any]:
-        """
-        Generates a complete response for a given prompt and parameters.
-
-        Args:
-            request_id: Unique identifier for the generation request.
-            prompt: The input prompt string.
-            params: Sampling and generation parameters.
-            abort_event: An asyncio Event to signal cancellation.
-            mm_embeddings: Optional multimodal embeddings.
-
-        Returns:
-            A dictionary containing the generation info
-        """
-
-        pass
-
-    @abc.abstractmethod
-    async def stream_generate(
-        self,
-        request_id: str,
-        prompt: str,
-        params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> AsyncIterator[Dict[str, Any]]:
-        """
-        Generates a response iteratively (streaming) for a given prompt.
-
-        Args:
-            request_id: Unique identifier for the generation request.
-            prompt: The input prompt string.
-            params: Sampling and generation parameters.
-            abort_event: An asyncio Event to signal cancellation.
-            mm_embeddings: Optional multimodal embeddings.
-
-        Yields:
-            Generation chunks
-        """
-
-        if False:
-            yield
-
     @abc.abstractmethod
     def model_info(self) -> ModelCard:
         """
@@ -239,3 +192,54 @@
         """

         return []
+
+    @abc.abstractmethod
+    async def generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> Dict[str, Any]:
+        """
+        Generates a complete response for a given prompt and parameters.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Returns:
+            A dictionary containing the generation info
+        """
+
+        pass
+
+    @abc.abstractmethod
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """
+        Generates a response iteratively (streaming) for a given prompt.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Yields:
+            Generation chunks
+        """
+
+        if False:
+            yield
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b821d1a..fcf4f3c 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -64,16 +64,19 @@ class ExllamaV2Container(BaseModelContainer):

     # Exl2 vars
     config: Optional[ExLlamaV2Config] = None
-    draft_config: Optional[ExLlamaV2Config] = None
     model: Optional[ExLlamaV2] = None
-    draft_model: Optional[ExLlamaV2] = None
     cache: Optional[ExLlamaV2Cache] = None
-    draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     paged: bool = True

+    # Draft model vars
+    use_draft_model: bool = False
+    draft_config: Optional[ExLlamaV2Config] = None
+    draft_model: Optional[ExLlamaV2] = None
+    draft_cache: Optional[ExLlamaV2Cache] = None
+
     # Internal config vars
     cache_size: int = None
     cache_mode: str = "FP16"
@@ -100,7 +103,7 @@ class ExllamaV2Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()

     @classmethod
-    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, **kwargs):
         """
         Primary asynchronous initializer for model container.

@@ -110,8 +113,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Create a new instance as a "fake self"
         self = cls()

-        self.quiet = quiet
-
         # Initialize config
         self.config = ExLlamaV2Config()
         self.model_dir = model_directory
@@ -162,17 +163,17 @@ class ExllamaV2Container(BaseModelContainer):
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
-        enable_draft = draft_args and draft_model_name
+        self.use_draft_model = draft_args and draft_model_name

         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
             logger.warning(
                 "Draft model is disabled because a model name "
                 "wasn't provided. Please check your config.yml!"
             )
-            enable_draft = False
+            self.use_draft_model = False

-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config = ExLlamaV2Config()
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
@@ -365,7 +366,7 @@ class ExllamaV2Container(BaseModelContainer):
         self.config.max_attention_size = chunk_size**2

         # Set user-configured draft model values
-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config.max_seq_len = self.config.max_seq_len

             self.draft_config.scale_pos_emb = unwrap(
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
new file mode 100644
index 0000000..76360df
--- /dev/null
+++ b/backends/exllamav3/model.py
@@ -0,0 +1,275 @@
+import asyncio
+import gc
+import pathlib
+from loguru import logger
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    List,
+    Optional,
+)
+
+import torch
+
+from backends.base_model_container import BaseModelContainer
+from common.concurrency import iterate_in_threadpool
+from common.multimodal import MultimodalEmbeddingWrapper
+from common.sampling import BaseSamplerRequest
+from common.templating import PromptTemplate
+from common.transformers_utils import GenerationConfig
+from endpoints.core.types.model import ModelCard
+
+from exllamav3 import Config, Model, Cache, Tokenizer
+
+
+class ExllamaV3Container(BaseModelContainer):
+    """Model container for ExLlamaV3 models."""
+
+    # Exposed model information
+    model_dir: pathlib.Path = pathlib.Path("models")
+    prompt_template: Optional[PromptTemplate] = None
+    generation_config: Optional[GenerationConfig] = None
+
+    # Load synchronization
+    # The bool is a master switch for accepting requests
+    # The lock keeps load tasks sequential
+    # The condition notifies any waiting tasks
+    active_job_ids: Dict[str, Any] = {}
+    loaded: bool = False
+    load_lock: asyncio.Lock = asyncio.Lock()
+    load_condition: asyncio.Condition = asyncio.Condition()
+
+    # Exl3 vars
+    model: Model
+    cache: Cache
+    tokenizer: Tokenizer
+    config: Config
+
+    # Required methods
+    @classmethod
+    async def create(cls, model_directory: pathlib.Path, **kwargs):
+        """
+        Asynchronously creates and initializes a model container instance.
+
+        Args:
+            model_directory: Path to the model files.
+            **kwargs: Backend-specific configuration options.
+
+        Returns:
+            An instance of the implementing class.
+        """
+
+        self = cls()
+
+        logger.warning(
+            "ExllamaV3 is currently in an alpha state. "
+            "Please note that some config options may not work yet."
+        )
+
+        self.config = Config.from_directory(model_directory.resolve())
+        self.model = Model.from_config(self.config)
+        self.tokenizer = Tokenizer.from_config(self.config)
+
+        max_seq_len = kwargs.get("max_seq_len")
+        self.cache = Cache(self.model, max_num_tokens=max_seq_len)
+
+        return self
+
+    async def load(self, progress_callback=None, **kwargs):
+        """
+        Loads the model into memory.
+
+        Args:
+            progress_callback: Optional callback for progress updates.
+            **kwargs: Additional loading options.
+        """
+
+        async for _ in self.load_gen(progress_callback):
+            pass
+
+    async def load_gen(self, progress_callback=None, **kwargs):
+        """
+        Loads the model into memory, yielding progress updates.
+
+        Args:
+            progress_callback: Optional callback for progress updates.
+            **kwargs: Additional loading options.
+
+        Yields:
+            Progress updates
+        """
+
+        try:
+            await self.load_lock.acquire()
+
+            # Wait for existing generation jobs to finish
+            await self.wait_for_jobs(kwargs.get("skip_wait"))
+
+            generator = self.load_model_sync(progress_callback)
+            async for module, modules in iterate_in_threadpool(generator):
+                yield module, modules
+
+            # Clean up any extra vram usage from torch and cuda
+            # (Helps reduce VRAM bottlenecking on Windows)
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            # Cleanup and update model load state
+            self.loaded = True
+            logger.info("Model successfully loaded.")
+        finally:
+            self.load_lock.release()
+
+            async with self.load_condition:
+                self.load_condition.notify_all()
+
+    # TODO: Add draft loading
+    @torch.inference_mode()
+    def load_model_sync(self, progress_callback=None):
+        for value in self.model.load_gen(callback=progress_callback):
+            if value:
+                yield value
+
+    async def unload(self, loras_only: bool = False, **kwargs):
+        """
+        Unloads the model and associated resources from memory.
+
+        Args:
+            loras_only: If True, only unload LoRAs.
+            **kwargs: Additional unloading options (e.g., shutdown).
+        """
+
+        try:
+            await self.load_lock.acquire()
+
+            # Wait for other jobs to finish
+            await self.wait_for_jobs(kwargs.get("skip_wait"))
+
+            self.model.unload()
+            self.model = None
+
+            self.config = None
+            self.cache = None
+            self.tokenizer = None
+
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            logger.info("Model unloaded.")
+        finally:
+            self.load_lock.release()
+
+            async with self.load_condition:
+                self.load_condition.notify_all()
+
+    def encode_tokens(self, text: str, **kwargs) -> List[int]:
+        """
+        Encodes a string of text into a list of token IDs.
+
+        Args:
+            text: The input text string.
+            **kwargs: Backend-specific encoding options (e.g., add_bos_token).
+
+        Returns:
+            A list of integer token IDs.
+        """
+
+        pass
+
+    def decode_tokens(self, ids: List[int], **kwargs) -> str:
+        """
+        Decodes a list of token IDs back into a string.
+
+        Args:
+            ids: A list of integer token IDs.
+            **kwargs: Backend-specific decoding options (e.g., decode_special_tokens).
+
+        Returns:
+            The decoded text string.
+        """
+
+        pass
+
+    def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
+        """
+        Gets special tokens used by the model/tokenizer.
+
+        Args:
+            **kwargs: Options like add_bos_token, ban_eos_token.
+
+        Returns:
+            A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
+            to their string or ID representation.
+        """
+
+        pass
+
+    def model_info(self) -> ModelCard:
+        """
+        Returns a dictionary of the current model's configuration parameters.
+
+        Returns:
+            Model parameters provided by the backend
+        """
+
+        pass
+
+    async def wait_for_jobs(self, skip_wait: bool = False):
+        """
+        Waits for any active generation jobs to complete.
+
+        Args:
+            skip_wait: If True, cancel jobs immediately instead of waiting.
+        """
+
+        pass
+
+    async def generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> Dict[str, Any]:
+        """
+        Generates a complete response for a given prompt and parameters.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Returns:
+            A dictionary containing the generation info
+        """
+
+        pass
+
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """
+        Generates a response iteratively (streaming) for a given prompt.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Yields:
+            Generation chunks
+        """
+
+        if False:
+            yield
diff --git a/common/model.py b/common/model.py
index 96d45f6..9de86f0 100644
--- a/common/model.py
+++ b/common/model.py
@@ -10,7 +10,7 @@ from enum import Enum
 from fastapi import HTTPException
 from loguru import logger
 from ruamel.yaml import YAML
-from typing import Optional
+from typing import Dict, Optional, Type

 from backends.base_model_container import BaseModelContainer
 from common.logger import get_loading_progress_bar
@@ -24,7 +24,7 @@
 container: Optional[BaseModelContainer] = None
 embeddings_container = None

-_BACKEND_REGISTRY = {}
+_BACKEND_REGISTRY: Dict[str, Type[BaseModelContainer]] = {}

 if dependencies.exllamav2:
     from backends.exllamav2.model import ExllamaV2Container
@@ -32,6 +32,12 @@ if dependencies.exllamav2:
     from backends.exllamav2.model import ExllamaV2Container

     _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container

+if dependencies.exllamav3:
+    from backends.exllamav3.model import ExllamaV3Container
+
+    _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container
+
+
 if dependencies.extras:
     from backends.infinity.model import InfinityContainer
@@ -134,7 +140,9 @@
            "Available backends: {available_backends}"
        )

-    new_container = await container_class.create(model_path.resolve(), False, **kwargs)
+    new_container: BaseModelContainer = await container_class.create(
+        model_path.resolve(), **kwargs
+    )

     # Add possible types of models that can be loaded
     model_type = [ModelType.MODEL]
@@ -142,7 +150,7 @@
     if new_container.use_vision:
         model_type.insert(0, ModelType.VISION)

-    if new_container.draft_config:
+    if new_container.use_draft_model:
         model_type.insert(0, ModelType.DRAFT)

     load_status = new_container.load_gen(load_progress, **kwargs)
diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py
index 06b1286..0c1e7ff 100644
--- a/common/optional_dependencies.py
+++ b/common/optional_dependencies.py
@@ -13,6 +13,7 @@ class DependenciesModel(BaseModel):

     torch: bool
     exllamav2: bool
+    exllamav3: bool
     flash_attn: bool
     infinity_emb: bool
     sentence_transformers: bool
@@ -25,7 +26,7 @@ class DependenciesModel(BaseModel):
     @computed_field
     @property
     def inference(self) -> bool:
-        return self.torch and self.exllamav2 and self.flash_attn
+        return self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn


 def is_installed(package_name: str) -> bool:
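
A note on the "if False: yield" bodies used by the abstract stream_generate stubs above: Python only treats an async function as an async generator when a yield appears somewhere in its body, so the unreachable yield keeps the stub usable with "async for" instead of returning a bare coroutine. A minimal standalone sketch of the idiom (illustrative names, not code from this change):

import asyncio
from typing import Any, AsyncIterator, Dict


async def stream_generate_stub() -> AsyncIterator[Dict[str, Any]]:
    # The unreachable yield turns this function into an async generator,
    # so callers can always iterate it; the stub simply yields nothing.
    if False:
        yield


async def main():
    chunks = [chunk async for chunk in stream_generate_stub()]
    assert chunks == []  # iterating the stub completes immediately


asyncio.run(main())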
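
On the _BACKEND_REGISTRY annotation in common/model.py: the registry maps a backend name to a container class, and an instance is only produced later through the async create() classmethod, which is why the value type is a class (Type[...]) rather than an instance. A hypothetical sketch of that lookup-then-create flow (FakeContainer, BACKEND_REGISTRY, and the model path are made up for illustration and are not part of this repository):

import asyncio
from typing import Dict, Type


class FakeContainer:
    # Hypothetical stand-in for a BaseModelContainer subclass.
    @classmethod
    async def create(cls, model_directory: str, **kwargs):
        self = cls()
        self.model_dir = model_directory
        return self


# The registry stores classes, not instances, hence Type[...] in the annotation.
BACKEND_REGISTRY: Dict[str, Type[FakeContainer]] = {"fake": FakeContainer}


async def main():
    container_class = BACKEND_REGISTRY["fake"]
    container = await container_class.create("models/example-model")
    print(type(container).__name__, container.model_dir)


asyncio.run(main())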
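
The inference property change in common/optional_dependencies.py makes either ExLlamaV2 or ExLlamaV3 sufficient on the backend side, while torch and flash_attn remain required. A small sketch of the same check, assuming pydantic v2 for computed_field (DependenciesSketch is a made-up model, not the project's class):

from pydantic import BaseModel, computed_field


class DependenciesSketch(BaseModel):
    torch: bool
    exllamav2: bool
    exllamav3: bool
    flash_attn: bool

    @computed_field
    @property
    def inference(self) -> bool:
        # Either exl2 or exl3 satisfies the backend requirement;
        # torch and flash_attn are still mandatory.
        return self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn


deps = DependenciesSketch(torch=True, exllamav2=False, exllamav3=True, flash_attn=True)
print(deps.inference)  # True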