Merge branch 'main' into robust-length-checks

This commit is contained in:
kingbri
2024-12-26 18:00:26 -05:00
6 changed files with 59 additions and 13 deletions

View File

@@ -486,16 +486,18 @@ class ExllamaV2Container:
"rope_scale": self.config.scale_pos_emb,
"rope_alpha": self.config.scale_alpha_value,
"max_seq_len": self.config.max_seq_len,
"max_batch_size": self.max_batch_size,
"cache_size": self.cache_size,
"cache_mode": self.cache_mode,
"chunk_size": self.config.max_input_len,
"num_experts_per_token": self.config.num_experts_per_token,
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
"use_vision": self.use_vision,
}
if self.prompt_template:
model_params["prompt_template"] = self.prompt_template.name
model_params["prompt_template_content"] = self.prompt_template.raw_template
if self.draft_config:
draft_model_params = {
"name": self.draft_model_dir.name,
@@ -759,6 +761,10 @@ class ExllamaV2Container:
                 max_batch_size=self.max_batch_size,
                 paged=self.paged,
             )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
         finally:
             # This means the generator is being recreated
             # The load lock is already released in the load function
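
Side note on the hunk above: when max_batch_size is left unset, the dynamic generator picks a batch size on its own, and the container copies that value back so later length and slot checks read the effective limit. A minimal sketch of the pattern, with hypothetical class and attribute names rather than tabbyAPI's actual API:

    class Container:
        def __init__(self, max_batch_size=None):
            # None means "let the generator decide"
            self.max_batch_size = max_batch_size

        def attach(self, generator):
            self.generator = generator
            # Copy the generator's computed batch size back into the
            # container so later checks see the effective value
            if self.max_batch_size is None:
                self.max_batch_size = generator.max_batch_size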

View File

@@ -14,14 +14,13 @@ class DependenciesModel(BaseModel):
     torch: bool
     exllamav2: bool
     flash_attn: bool
-    outlines: bool
     infinity_emb: bool
     sentence_transformers: bool

     @computed_field
     @property
     def extras(self) -> bool:
-        return self.outlines and self.infinity_emb and self.sentence_transformers
+        return self.infinity_emb and self.sentence_transformers

     @computed_field
     @property
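
For context, extras is a Pydantic v2 computed field, so it is serialized alongside the regular booleans and recomputes automatically; dropping outlines from the conjunction is the whole change. A self-contained sketch of the mechanism:

    from pydantic import BaseModel, computed_field

    class DependenciesModel(BaseModel):
        infinity_emb: bool
        sentence_transformers: bool

        @computed_field
        @property
        def extras(self) -> bool:
            # True only when every optional extra is installed
            return self.infinity_emb and self.sentence_transformers

    print(DependenciesModel(infinity_emb=True, sentence_transformers=False).model_dump())
    # {'infinity_emb': True, 'sentence_transformers': False, 'extras': False}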

View File

@@ -2,7 +2,7 @@ import asyncio
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
 from loguru import logger
-from sse_starlette import ServerSentEvent
+from sse_starlette.event import ServerSentEvent

 from common import model
 from common.networking import (

View File

@@ -23,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons
 from endpoints.core.types.model import (
     EmbeddingModelLoadRequest,
     ModelCard,
+    ModelDefaultGenerationSettings,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
+    ModelPropsResponse,
 )
 from endpoints.core.types.health import HealthCheckResponse
 from endpoints.core.types.sampler_overrides import (
@@ -131,6 +133,30 @@ async def current_model() -> ModelCard:
     return get_current_model()


+@router.get(
+    "/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+)
+async def model_props() -> ModelPropsResponse:
+    """
+    Returns specific properties of a model for clients.
+    To get all properties, use /v1/model instead.
+    """
+
+    current_model_card = get_current_model()
+
+    resp = ModelPropsResponse(
+        total_slots=current_model_card.parameters.max_batch_size,
+        default_generation_settings=ModelDefaultGenerationSettings(
+            n_ctx=current_model_card.parameters.max_seq_len,
+        ),
+    )
+
+    if current_model_card.parameters.prompt_template_content:
+        resp.chat_template = current_model_card.parameters.prompt_template_content
+
+    return resp
+
+
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models(request: Request) -> ModelList:
     """

View File

@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = None
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
+    max_batch_size: Optional[int] = 1
     cache_size: Optional[int] = None
     cache_mode: Optional[str] = "FP16"
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
+    prompt_template_content: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_vision: Optional[bool] = False
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
     module: int
     modules: int
     status: str
+
+
+class ModelDefaultGenerationSettings(BaseModel):
+    """Contains default generation settings for model props."""
+
+    n_ctx: int
+
+
+class ModelPropsResponse(BaseModel):
+    """Represents a model props response."""
+
+    total_slots: int = 1
+    chat_template: str = ""
+    default_generation_settings: ModelDefaultGenerationSettings
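
Putting the two new models together: the defaults (total_slots=1, chat_template="") keep the response valid even when no template is loaded, and the generation settings nest one level deep. An illustrative round trip, importing the classes as the router hunk above does; the values are made up:

    from endpoints.core.types.model import (
        ModelDefaultGenerationSettings,
        ModelPropsResponse,
    )

    resp = ModelPropsResponse(
        total_slots=4,
        default_generation_settings=ModelDefaultGenerationSettings(n_ctx=8192),
    )
    print(resp.model_dump())
    # {'total_slots': 4, 'chat_template': '',
    #  'default_generation_settings': {'n_ctx': 8192}}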

View File

@@ -16,31 +16,30 @@ version = "0.0.1"
description = "An OAI compatible exllamav2 API that's both lightweight and fast"
requires-python = ">=3.10"
dependencies = [
"fastapi-slim >= 0.110.0",
"fastapi-slim >= 0.115",
"pydantic >= 2.0.0",
"ruamel.yaml",
"rich",
"uvicorn >= 0.28.1",
"jinja2 >= 3.0.0",
"loguru",
"sse-starlette",
"sse-starlette >= 2.2.0",
"packaging",
"tokenizers>=0.21.0",
"formatron",
"kbnf>=0.4.1",
"tokenizers >= 0.21.0",
"formatron >= 0.4.10",
"kbnf >= 0.4.1",
"aiofiles",
"aiohttp",
"async_lru",
"huggingface_hub",
"psutil",
"httptools>=0.5.0",
"httptools >= 0.5.0",
"pillow",
# Improved asyncio loops
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
"winloop ; platform_system == 'Windows'",
# TEMP: Remove once 2.x is fixed in upstream
"numpy < 2.0.0",
# For python 3.12