mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API + Model: Fix application of defaults
use_as_default was not being properly applied to model overrides. For compartmentalization's sake, apply all overrides in a single function to avoid clutter. In addition, fix where the traditional /v1/model/load endpoint checks for draft options. These can be applied via an inline config, so let any failures fall through. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -68,10 +68,14 @@ def detect_backend(hf_model: HFModel) -> str:
|
|||||||
return "exllamav2"
|
return "exllamav2"
|
||||||
|
|
||||||
|
|
||||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
async def apply_load_defaults(model_path: pathlib.Path, **kwargs):
|
||||||
"""Sets overrides from a model folder's config yaml."""
|
"""
|
||||||
|
Applies model load overrides.
|
||||||
|
Sources are from inline config and use_as_default.
|
||||||
|
Currently agnostic due to different schemas for API and config.
|
||||||
|
"""
|
||||||
|
|
||||||
override_config_path = model_dir / "tabby_config.yml"
|
override_config_path = model_path / "tabby_config.yml"
|
||||||
|
|
||||||
if not override_config_path.exists():
|
if not override_config_path.exists():
|
||||||
return kwargs
|
return kwargs
|
||||||
@@ -88,20 +92,23 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
|||||||
yaml = YAML(typ="safe")
|
yaml = YAML(typ="safe")
|
||||||
inline_config = unwrap(yaml.load(contents), {})
|
inline_config = unwrap(yaml.load(contents), {})
|
||||||
|
|
||||||
# Check for inline model overrides
|
# Check for inline model overrides and merge config defaults
|
||||||
model_inline_config = unwrap(inline_config.get("model"), {})
|
model_inline_config = unwrap(inline_config.get("model"), {})
|
||||||
if model_inline_config:
|
if model_inline_config:
|
||||||
overrides = {**model_inline_config}
|
overrides = {**model_inline_config, **config.model_defaults}
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Cannot find inline model overrides. "
|
"Cannot find inline model overrides. "
|
||||||
'Make sure they are nested under a "model:" key'
|
'Make sure they are nested under a "model:" key'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge draft overrides beforehand
|
# Merge draft overrides beforehand and merge config defaults
|
||||||
draft_inline_config = unwrap(inline_config.get("draft_model"), {})
|
draft_inline_config = unwrap(inline_config.get("draft_model"), {})
|
||||||
if draft_inline_config:
|
if draft_inline_config:
|
||||||
overrides["draft_model"] = {**draft_inline_config}
|
overrides["draft_model"] = {
|
||||||
|
**draft_inline_config,
|
||||||
|
**config.draft_model_defaults,
|
||||||
|
}
|
||||||
|
|
||||||
# Merge the override and model kwargs
|
# Merge the override and model kwargs
|
||||||
# No need to preserve the original overrides dict
|
# No need to preserve the original overrides dict
|
||||||
@@ -143,8 +150,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||||||
# Merge with config and inline defaults
|
# Merge with config and inline defaults
|
||||||
# TODO: Figure out a way to do this with Pydantic validation
|
# TODO: Figure out a way to do this with Pydantic validation
|
||||||
# and ModelLoadRequest. Pydantic doesn't have async validators
|
# and ModelLoadRequest. Pydantic doesn't have async validators
|
||||||
kwargs = {**config.model_defaults, **kwargs}
|
kwargs = await apply_load_defaults(model_path, **kwargs)
|
||||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
|
||||||
|
|
||||||
# Fetch the extra HF configuration options
|
# Fetch the extra HF configuration options
|
||||||
hf_model = await HFModel.from_directory(model_path)
|
hf_model = await HFModel.from_directory(model_path)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class TabbyConfig(TabbyConfigModel):
|
|||||||
# Persistent defaults
|
# Persistent defaults
|
||||||
# TODO: make this pydantic?
|
# TODO: make this pydantic?
|
||||||
model_defaults: dict = {}
|
model_defaults: dict = {}
|
||||||
|
draft_model_defaults: dict = {}
|
||||||
|
|
||||||
def load(self, arguments: Optional[dict] = None):
|
def load(self, arguments: Optional[dict] = None):
|
||||||
"""Synchronously loads the global application config"""
|
"""Synchronously loads the global application config"""
|
||||||
@@ -50,7 +51,7 @@ class TabbyConfig(TabbyConfigModel):
|
|||||||
if hasattr(self.model, field):
|
if hasattr(self.model, field):
|
||||||
self.model_defaults[field] = getattr(config.model, field)
|
self.model_defaults[field] = getattr(config.model, field)
|
||||||
elif hasattr(self.draft_model, field):
|
elif hasattr(self.draft_model, field):
|
||||||
self.model_defaults[field] = getattr(config.draft_model, field)
|
self.draft_model_defaults[field] = getattr(config.draft_model, field)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"invalid item {field} in config option `model.use_as_default`"
|
f"invalid item {field} in config option `model.use_as_default`"
|
||||||
|
|||||||
@@ -193,18 +193,6 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
model_path = pathlib.Path(config.model.model_dir)
|
model_path = pathlib.Path(config.model.model_dir)
|
||||||
model_path = model_path / data.model_name
|
model_path = model_path / data.model_name
|
||||||
|
|
||||||
draft_model_path = None
|
|
||||||
if data.draft_model:
|
|
||||||
if not data.draft_model.draft_model_name:
|
|
||||||
error_message = handle_request_error(
|
|
||||||
"Could not find the draft model name for model load.",
|
|
||||||
exc_info=False,
|
|
||||||
).error.message
|
|
||||||
|
|
||||||
raise HTTPException(400, error_message)
|
|
||||||
|
|
||||||
draft_model_path = config.draft_model.draft_model_dir
|
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
"Could not find the model path for load. Check model name or config.yml?",
|
"Could not find the model path for load. Check model name or config.yml?",
|
||||||
@@ -213,9 +201,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(stream_model_load(data, model_path), ping=maxsize)
|
||||||
stream_model_load(data, model_path, draft_model_path), ping=maxsize
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Unload model endpoint
|
# Unload model endpoint
|
||||||
|
|||||||
@@ -77,16 +77,16 @@ def get_dummy_models():
|
|||||||
async def stream_model_load(
|
async def stream_model_load(
|
||||||
data: ModelLoadRequest,
|
data: ModelLoadRequest,
|
||||||
model_path: pathlib.Path,
|
model_path: pathlib.Path,
|
||||||
draft_model_path: str,
|
|
||||||
):
|
):
|
||||||
"""Request generation wrapper for the loading process."""
|
"""Request generation wrapper for the loading process."""
|
||||||
|
|
||||||
# Get trimmed load data
|
# Get trimmed load data
|
||||||
load_data = data.model_dump(exclude_none=True)
|
load_data = data.model_dump(exclude_none=True)
|
||||||
|
|
||||||
# Set the draft model path if it exists
|
# Set the draft model directory
|
||||||
if draft_model_path:
|
load_data.setdefault("draft_model", {})["draft_model_dir"] = (
|
||||||
load_data["draft_model"]["draft_model_dir"] = draft_model_path
|
config.draft_model.draft_model_dir
|
||||||
|
)
|
||||||
|
|
||||||
load_status = model.load_model_gen(
|
load_status = model.load_model_gen(
|
||||||
model_path, skip_wait=data.skip_queue, **load_data
|
model_path, skip_wait=data.skip_queue, **load_data
|
||||||
|
|||||||
Reference in New Issue
Block a user