diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3478820..d9a52ac 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -486,16 +486,18 @@ class ExllamaV2Container: "rope_scale": self.config.scale_pos_emb, "rope_alpha": self.config.scale_alpha_value, "max_seq_len": self.config.max_seq_len, + "max_batch_size": self.max_batch_size, "cache_size": self.cache_size, "cache_mode": self.cache_mode, "chunk_size": self.config.max_input_len, "num_experts_per_token": self.config.num_experts_per_token, - "prompt_template": self.prompt_template.name - if self.prompt_template - else None, "use_vision": self.use_vision, } + if self.prompt_template: + model_params["prompt_template"] = self.prompt_template.name + model_params["prompt_template_content"] = self.prompt_template.raw_template + if self.draft_config: draft_model_params = { "name": self.draft_model_dir.name, @@ -759,6 +761,10 @@ class ExllamaV2Container: max_batch_size=self.max_batch_size, paged=self.paged, ) + + # Update the state of the container var + if self.max_batch_size is None: + self.max_batch_size = self.generator.generator.max_batch_size finally: # This means the generator is being recreated # The load lock is already released in the load function diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index e98a668..06b1286 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -14,14 +14,13 @@ class DependenciesModel(BaseModel): torch: bool exllamav2: bool flash_attn: bool - outlines: bool infinity_emb: bool sentence_transformers: bool @computed_field @property def extras(self) -> bool: - return self.outlines and self.infinity_emb and self.sentence_transformers + return self.infinity_emb and self.sentence_transformers @computed_field @property diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py index 8ffbf0d..2086788 100644 --- a/endpoints/Kobold/utils/generation.py +++ b/endpoints/Kobold/utils/generation.py @@ -2,7 +2,7 @@ import asyncio from asyncio import CancelledError from fastapi import HTTPException, Request from loguru import logger -from sse_starlette import ServerSentEvent +from sse_starlette.event import ServerSentEvent from common import model from common.networking import ( diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 6b48182..6217475 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -23,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons from endpoints.core.types.model import ( EmbeddingModelLoadRequest, ModelCard, + ModelDefaultGenerationSettings, ModelList, ModelLoadRequest, ModelLoadResponse, + ModelPropsResponse, ) from endpoints.core.types.health import HealthCheckResponse from endpoints.core.types.sampler_overrides import ( @@ -131,6 +133,30 @@ async def current_model() -> ModelCard: return get_current_model() +@router.get( + "/props", dependencies=[Depends(check_api_key), Depends(check_model_container)] +) +async def model_props() -> ModelPropsResponse: + """ + Returns specific properties of a model for clients. + + To get all properties, use /v1/model instead. + """ + + current_model_card = get_current_model() + resp = ModelPropsResponse( + total_slots=current_model_card.parameters.max_batch_size, + default_generation_settings=ModelDefaultGenerationSettings( + n_ctx=current_model_card.parameters.max_seq_len, + ), + ) + + if current_model_card.parameters.prompt_template_content: + resp.chat_template = current_model_card.parameters.prompt_template_content + + return resp + + @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(request: Request) -> ModelList: """ diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index ddf1cc2..8a2e55e 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel): max_seq_len: Optional[int] = None rope_scale: Optional[float] = 1.0 rope_alpha: Optional[float] = 1.0 + max_batch_size: Optional[int] = 1 cache_size: Optional[int] = None cache_mode: Optional[str] = "FP16" chunk_size: Optional[int] = 2048 prompt_template: Optional[str] = None + prompt_template_content: Optional[str] = None num_experts_per_token: Optional[int] = None use_vision: Optional[bool] = False @@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel): module: int modules: int status: str + + +class ModelDefaultGenerationSettings(BaseModel): + """Contains default generation settings for model props.""" + + n_ctx: int + + +class ModelPropsResponse(BaseModel): + """Represents a model props response.""" + + total_slots: int = 1 + chat_template: str = "" + default_generation_settings: ModelDefaultGenerationSettings diff --git a/pyproject.toml b/pyproject.toml index 5d7acb7..5b650d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,31 +16,30 @@ version = "0.0.1" description = "An OAI compatible exllamav2 API that's both lightweight and fast" requires-python = ">=3.10" dependencies = [ - "fastapi-slim >= 0.110.0", + "fastapi-slim >= 0.115", "pydantic >= 2.0.0", "ruamel.yaml", "rich", "uvicorn >= 0.28.1", "jinja2 >= 3.0.0", "loguru", - "sse-starlette", + "sse-starlette >= 2.2.0", "packaging", - "tokenizers>=0.21.0", - "formatron", - "kbnf>=0.4.1", + "tokenizers >= 0.21.0", + "formatron >= 0.4.10", + "kbnf >= 0.4.1", "aiofiles", "aiohttp", "async_lru", "huggingface_hub", "psutil", - "httptools>=0.5.0", + "httptools >= 0.5.0", "pillow", # Improved asyncio loops "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", "winloop ; platform_system == 'Windows'", - # TEMP: Remove once 2.x is fixed in upstream "numpy < 2.0.0", # For python 3.12