Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
Merge branch 'main' into robust-length-checks
@@ -486,16 +486,18 @@ class ExllamaV2Container:
             "rope_scale": self.config.scale_pos_emb,
             "rope_alpha": self.config.scale_alpha_value,
             "max_seq_len": self.config.max_seq_len,
             "max_batch_size": self.max_batch_size,
             "cache_size": self.cache_size,
             "cache_mode": self.cache_mode,
             "chunk_size": self.config.max_input_len,
             "num_experts_per_token": self.config.num_experts_per_token,
-            "prompt_template": self.prompt_template.name
-            if self.prompt_template
-            else None,
+            "use_vision": self.use_vision,
         }
 
+        if self.prompt_template:
+            model_params["prompt_template"] = self.prompt_template.name
+            model_params["prompt_template_content"] = self.prompt_template.raw_template
+
         if self.draft_config:
             draft_model_params = {
                 "name": self.draft_model_dir.name,
@@ -759,6 +761,10 @@ class ExllamaV2Container:
                 max_batch_size=self.max_batch_size,
                 paged=self.paged,
             )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
         finally:
             # This means the generator is being recreated
             # The load lock is already released in the load function
@@ -14,14 +14,13 @@ class DependenciesModel(BaseModel):
     torch: bool
     exllamav2: bool
     flash_attn: bool
-    outlines: bool
     infinity_emb: bool
     sentence_transformers: bool
 
     @computed_field
     @property
     def extras(self) -> bool:
-        return self.outlines and self.infinity_emb and self.sentence_transformers
+        return self.infinity_emb and self.sentence_transformers
 
     @computed_field
     @property
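For context, a minimal standalone sketch (not the project's actual module, assuming Pydantic v2) of how a @computed_field property such as extras ends up in serialized output:

from pydantic import BaseModel, computed_field


class DependenciesModel(BaseModel):
    infinity_emb: bool
    sentence_transformers: bool

    @computed_field
    @property
    def extras(self) -> bool:
        # Mirrors the change above: "extras" no longer depends on outlines
        return self.infinity_emb and self.sentence_transformers


print(DependenciesModel(infinity_emb=True, sentence_transformers=False).model_dump())
# -> {'infinity_emb': True, 'sentence_transformers': False, 'extras': False}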
@@ -2,7 +2,7 @@ import asyncio
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
 from loguru import logger
-from sse_starlette import ServerSentEvent
+from sse_starlette.event import ServerSentEvent
 
 from common import model
 from common.networking import (
@@ -23,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
 from endpoints.core.types.model import (
     EmbeddingModelLoadRequest,
     ModelCard,
+    ModelDefaultGenerationSettings,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
+    ModelPropsResponse,
 )
 from endpoints.core.types.health import HealthCheckResponse
 from endpoints.core.types.sampler_overrides import (
@@ -131,6 +133,30 @@ async def current_model() -> ModelCard:
     return get_current_model()
 
 
+@router.get(
+    "/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+)
+async def model_props() -> ModelPropsResponse:
+    """
+    Returns specific properties of a model for clients.
+
+    To get all properties, use /v1/model instead.
+    """
+
+    current_model_card = get_current_model()
+    resp = ModelPropsResponse(
+        total_slots=current_model_card.parameters.max_batch_size,
+        default_generation_settings=ModelDefaultGenerationSettings(
+            n_ctx=current_model_card.parameters.max_seq_len,
+        ),
+    )
+
+    if current_model_card.parameters.prompt_template_content:
+        resp.chat_template = current_model_card.parameters.prompt_template_content
+
+    return resp
+
+
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models(request: Request) -> ModelList:
     """
@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = None
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
     max_batch_size: Optional[int] = 1
     cache_size: Optional[int] = None
     cache_mode: Optional[str] = "FP16"
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
+    prompt_template_content: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_vision: Optional[bool] = False
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
     module: int
     modules: int
     status: str
+
+
+class ModelDefaultGenerationSettings(BaseModel):
+    """Contains default generation settings for model props."""
+
+    n_ctx: int
+
+
+class ModelPropsResponse(BaseModel):
+    """Represents a model props response."""
+
+    total_slots: int = 1
+    chat_template: str = ""
+    default_generation_settings: ModelDefaultGenerationSettings
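A minimal sketch (assuming Pydantic v2, with illustrative values only) of how these two models serialize, showing the JSON that /props clients receive:

from pydantic import BaseModel


class ModelDefaultGenerationSettings(BaseModel):
    """Contains default generation settings for model props."""

    n_ctx: int


class ModelPropsResponse(BaseModel):
    """Represents a model props response."""

    total_slots: int = 1
    chat_template: str = ""
    default_generation_settings: ModelDefaultGenerationSettings


# Example values, not taken from a real model load
resp = ModelPropsResponse(
    total_slots=4,
    default_generation_settings=ModelDefaultGenerationSettings(n_ctx=8192),
)
print(resp.model_dump_json())
# -> {"total_slots":4,"chat_template":"","default_generation_settings":{"n_ctx":8192}}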
@@ -16,31 +16,30 @@ version = "0.0.1"
 description = "An OAI compatible exllamav2 API that's both lightweight and fast"
 requires-python = ">=3.10"
 dependencies = [
-    "fastapi-slim >= 0.110.0",
+    "fastapi-slim >= 0.115",
     "pydantic >= 2.0.0",
     "ruamel.yaml",
     "rich",
     "uvicorn >= 0.28.1",
     "jinja2 >= 3.0.0",
     "loguru",
-    "sse-starlette",
+    "sse-starlette >= 2.2.0",
     "packaging",
-    "tokenizers>=0.21.0",
-    "formatron",
-    "kbnf>=0.4.1",
+    "tokenizers >= 0.21.0",
+    "formatron >= 0.4.10",
+    "kbnf >= 0.4.1",
     "aiofiles",
     "aiohttp",
     "async_lru",
     "huggingface_hub",
     "psutil",
-    "httptools>=0.5.0",
+    "httptools >= 0.5.0",
     "pillow",
 
     # Improved asyncio loops
     "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
     "winloop ; platform_system == 'Windows'",
 
     # TEMP: Remove once 2.x is fixed in upstream
     "numpy < 2.0.0",
 
     # For python 3.12