Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-04-26 09:18:53 +00:00
API + Model: Apply config.yml defaults for all load paths
There are two ways to load a model:

1. Via the load endpoint
2. Inline with a completion request

The defaults were not applied on the inline load, so rewrite the load path to fix that. While doing this, build a defaults dictionary up front rather than consulting the config at runtime, and remove the pydantic default lambdas on all the model load fields. This makes the code cleaner and establishes a clear config tree for loading models.

Signed-off-by: kingbri <bdashore3@proton.me>
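In outline, the fix routes both load paths through a single point where config.yml defaults are merged in, with explicit request arguments keeping precedence. A minimal sketch of that shape (the function names here are illustrative, not tabbyAPI's actual API):

```python
# Hypothetical sketch: both load paths funnel through one loader that
# applies config.yml defaults. Names are illustrative only.
CONFIG_DEFAULTS = {"cache_mode": "FP16", "chunk_size": 2048}


def load_model(**kwargs):
    """Single choke point: config defaults apply to every caller."""
    # Explicit request arguments override the config defaults
    kwargs = {**CONFIG_DEFAULTS, **kwargs}
    return kwargs


def load_endpoint(request_args: dict):
    # Path 1: the explicit load endpoint
    return load_model(**request_args)


def inline_load_for_completion(model_name: str):
    # Path 2: inline load triggered by a completion request
    return load_model(model_name=model_name)


print(load_model(chunk_size=4096))
# {'cache_mode': 'FP16', 'chunk_size': 4096}
```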
In `common/model.py` (the module imported below as `common.model`), the now-unused `unwrap` import is dropped:

@@ -13,7 +13,6 @@ from typing import Optional
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
-from common.utils import unwrap
 from endpoints.utils import do_export_openapi
 
 if not do_export_openapi:
@@ -67,6 +66,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
         logger.info("Unloading existing model.")
         await unload_model()
 
+    # Merge with config defaults
+    kwargs = {**config.model_defaults, **kwargs}
+
+    # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
     model_type = "draft" if container.draft_config else "model"
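The merge relies on Python's dict-unpacking rule that later entries win, so request kwargs override the config defaults. One caveat worth noting: a key present in `kwargs` with an explicit `None` value still shadows its default, so unset request fields must be omitted rather than sent as `None` (see the request-type changes below). A quick demonstration:

```python
defaults = {"cache_mode": "FP16", "chunk_size": 2048}

# Later unpacking wins: the request's chunk_size overrides the default
merged = {**defaults, **{"chunk_size": 4096}}
assert merged == {"cache_mode": "FP16", "chunk_size": 4096}

# Caveat: an explicit None also "wins", so unset fields must be dropped
# before merging instead of being passed through as None
merged_none = {**defaults, **{"chunk_size": None}}
assert merged_none["chunk_size"] is None
```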
@@ -149,25 +152,6 @@ async def unload_embedding_model():
     embeddings_container = None
 
 
-# FIXME: Maybe make this a one-time function instead of a dynamic default
-def get_config_default(key: str, model_type: str = "model"):
-    """Fetches a default value from model config if allowed by the user."""
-
-    default_keys = unwrap(config.model.get("use_as_default"), [])
-
-    # Add extra keys to defaults
-    default_keys.append("embeddings_device")
-
-    if key in default_keys:
-        # Is this a draft model load parameter?
-        if model_type == "draft":
-            return config.draft_model.get(key)
-        elif model_type == "embedding":
-            return config.embeddings.get(key)
-        else:
-            return config.model.get(key)
-
-
 async def check_model_container():
     """FastAPI depends that checks if a model isn't loaded or currently loading."""
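For context on why this helper could be deleted: it was invoked from pydantic `default_factory` lambdas in the request types, and `default_factory` runs on every model instantiation, so each request re-walked the config. A small standalone pydantic example (not tabbyAPI code) showing that timing:

```python
from typing import Optional

from pydantic import BaseModel, Field

lookups = 0


def config_lookup() -> Optional[int]:
    """Stand-in for the removed per-request config walk."""
    global lookups
    lookups += 1
    return 4096


class Request(BaseModel):
    max_seq_len: Optional[int] = Field(default_factory=config_lookup)


Request()
Request()
print(lookups)  # 2: the factory runs on every instantiation
```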
In `common/tabby_config.py`, `TabbyConfig` gains a persistent `model_defaults` dict, built once during `load()`:

@@ -7,6 +7,9 @@ from common.utils import unwrap, merge_dicts
 
 
 class TabbyConfig:
+    """Common config class for TabbyAPI. Loaded into sub-dictionaries from YAML file."""
+
+    # Sub-blocks of yaml
     network: dict = {}
     logging: dict = {}
     model: dict = {}
@@ -16,6 +19,9 @@ class TabbyConfig:
     developer: dict = {}
     embeddings: dict = {}
 
+    # Persistent defaults
+    model_defaults: dict = {}
+
     def load(self, arguments: Optional[dict] = None):
         """Synchronously loads the global application config"""
 
@@ -36,6 +42,14 @@ class TabbyConfig:
         self.developer = unwrap(merged_config.get("developer"), {})
         self.embeddings = unwrap(merged_config.get("embeddings"), {})
 
+        # Set model defaults dict once to prevent on-demand reconstruction
+        default_keys = unwrap(self.model.get("use_as_default"), [])
+        for key in default_keys:
+            if key in self.model:
+                self.model_defaults[key] = config.model[key]
+            elif key in self.draft_model:
+                self.model_defaults[key] = config.draft_model[key]
+
     def _from_file(self, config_path: pathlib.Path):
         """loads config from a given file path"""
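The new `load()` logic amounts to snapshotting whitelisted config keys into one flat dict, once. A self-contained sketch of that pattern with a simplified config layout (the real code reads `use_as_default` from the model block, as above):

```python
from typing import Any


def build_model_defaults(
    model: dict, draft_model: dict, use_as_default: list
) -> dict[str, Any]:
    """Snapshot whitelisted config values into one flat defaults dict."""
    defaults: dict[str, Any] = {}
    for key in use_as_default:
        if key in model:
            defaults[key] = model[key]
        elif key in draft_model:
            defaults[key] = draft_model[key]
    return defaults


model_cfg = {
    "cache_mode": "Q4",
    "chunk_size": 2048,
    "use_as_default": ["cache_mode", "draft_rope_scale"],
}
draft_cfg = {"draft_rope_scale": 1.0}

print(build_model_defaults(model_cfg, draft_cfg, model_cfg["use_as_default"]))
# {'cache_mode': 'Q4', 'draft_rope_scale': 1.0}
```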
In the endpoints' model request types, the imports swap accordingly:

@@ -5,7 +5,8 @@ from time import time
 from typing import List, Literal, Optional, Union
 
 from common.gen_logging import GenLogPreferences
-from common.model import get_config_default
+from common.tabby_config import config
+from common.utils import unwrap
 
 
 class ModelCardParameters(BaseModel):
@@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel):
     draft_model_name: str
 
     # Config arguments
-    draft_rope_scale: Optional[float] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_rope_scale", model_type="draft"
-        )
-    )
+    draft_rope_scale: Optional[float] = None
     draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default(
-            "draft_rope_alpha", model_type="draft"
-        ),
+        default=None,
         examples=[1.0],
     )
-    draft_cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_cache_mode", model_type="draft"
-        )
-    )
+    draft_cache_mode: Optional[str] = None
 
 
 class ModelLoadRequest(BaseModel):
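The pattern in these request types: fields whose `Field(...)` existed only to carry a config-backed `default_factory` collapse to a bare `= None`, while fields with OpenAPI metadata keep `Field(...)` but switch to `default=None`. A minimal pydantic example of the resulting shape (standalone, not the project's full type):

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field


class DraftRequest(BaseModel):
    # No metadata needed: a bare None default suffices
    draft_rope_scale: Optional[float] = None

    # Metadata kept, but the default no longer reaches into the config
    draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
        default=None,
        description='Automatically calculated if set to "auto"',
        examples=[1.0],
    )


print(DraftRequest().draft_rope_alpha)  # None
```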
@@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel):
 
     # Config arguments
 
-    # Max seq len is fetched from config.json of the model by default
     max_seq_len: Optional[int] = Field(
         description="Leave this blank to use the model's base sequence length",
-        default_factory=lambda: get_config_default("max_seq_len"),
+        default=None,
         examples=[4096],
     )
     override_base_seq_len: Optional[int] = Field(
         description=(
             "Overrides the model's base sequence length. " "Leave blank if unsure"
         ),
-        default_factory=lambda: get_config_default("override_base_seq_len"),
+        default=None,
         examples=[4096],
     )
     cache_size: Optional[int] = Field(
         description=("Number in tokens, must be greater than or equal to max_seq_len"),
-        default_factory=lambda: get_config_default("cache_size"),
+        default=None,
         examples=[4096],
     )
-    tensor_parallel: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("tensor_parallel")
-    )
-    gpu_split_auto: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("gpu_split_auto")
-    )
-    autosplit_reserve: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("autosplit_reserve")
-    )
+    tensor_parallel: Optional[bool] = None
+    gpu_split_auto: Optional[bool] = None
+    autosplit_reserve: Optional[List[float]] = None
     gpu_split: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("gpu_split"),
+        default=None,
         examples=[[24.0, 20.0]],
     )
     rope_scale: Optional[float] = Field(
         description="Automatically pulled from the model's config if not present",
-        default_factory=lambda: get_config_default("rope_scale"),
+        default=None,
         examples=[1.0],
     )
     rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default("rope_alpha"),
+        default=None,
         examples=[1.0],
     )
-    cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default("cache_mode")
-    )
-    chunk_size: Optional[int] = Field(
-        default_factory=lambda: get_config_default("chunk_size")
-    )
-    prompt_template: Optional[str] = Field(
-        default_factory=lambda: get_config_default("prompt_template")
-    )
-    num_experts_per_token: Optional[int] = Field(
-        default_factory=lambda: get_config_default("num_experts_per_token")
-    )
-    fasttensors: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("fasttensors")
-    )
+    cache_mode: Optional[str] = None
+    chunk_size: Optional[int] = None
+    prompt_template: Optional[str] = None
+    num_experts_per_token: Optional[int] = None
+    fasttensors: Optional[bool] = None
 
     # Non-config arguments
     draft: Optional[DraftModelLoadRequest] = None
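With every config-backed field now defaulting to `None`, the calling endpoint presumably drops unset fields before handing kwargs to the merge in `load_model_gen`; pydantic's `exclude_none` does exactly that. A hedged sketch (the dump-and-merge step is my assumption about the caller, which this diff does not show):

```python
from typing import Optional

from pydantic import BaseModel


class LoadRequest(BaseModel):
    name: str
    cache_mode: Optional[str] = None
    chunk_size: Optional[int] = None


model_defaults = {"cache_mode": "FP16", "chunk_size": 2048}

request = LoadRequest(name="my-model", chunk_size=4096)

# exclude_none drops unset fields so their None values cannot shadow
# the config defaults during the merge (assumed caller behavior)
kwargs = request.model_dump(exclude_none=True)
print({**model_defaults, **kwargs})
# {'cache_mode': 'FP16', 'chunk_size': 4096, 'name': 'my-model'}
```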
@@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel):
 
 class EmbeddingModelLoadRequest(BaseModel):
     name: str
 
+    # Set default from the config
     embeddings_device: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "embeddings_device", model_type="embedding"
-        )
+        default_factory=lambda: unwrap(
+            config.embeddings.get("embeddings_device"), "cpu"
+        )
     )
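For reference, `unwrap` from `common/utils.py` is a null-coalescing helper; judging by its usage here it behaves like the sketch below (reconstructed from usage, not copied from the repo):

```python
def unwrap(wrapped, default=None):
    """Return the value if it is not None, otherwise the default."""
    return wrapped if wrapped is not None else default


# As used above: fall back to "cpu" when the config key is absent
embeddings_config = {}
print(unwrap(embeddings_config.get("embeddings_device"), "cpu"))  # cpu
```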