diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 1b4f0b6..cb12a15 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -145,9 +145,6 @@ class ExllamaV2Container(BaseModelContainer): "Skipping generation config load because of an unexpected error." ) - # Apply a model's config overrides while respecting user settings - kwargs = await self.set_model_overrides(**kwargs) - # Set vision state and error if vision isn't supported on the current model self.use_vision = unwrap(kwargs.get("vision"), False) if self.use_vision and not self.config.vision_model_type: @@ -385,35 +382,6 @@ class ExllamaV2Container(BaseModelContainer): # Return the created instance return self - async def set_model_overrides(self, **kwargs): - """Sets overrides from a model folder's config yaml.""" - - override_config_path = self.model_dir / "tabby_config.yml" - - if not override_config_path.exists(): - return kwargs - - async with aiofiles.open( - override_config_path, "r", encoding="utf8" - ) as override_config_file: - contents = await override_config_file.read() - - # Create a temporary YAML parser - yaml = YAML(typ="safe") - override_args = unwrap(yaml.load(contents), {}) - - # Merge draft overrides beforehand - draft_override_args = unwrap(override_args.get("draft_model"), {}) - if draft_override_args: - kwargs["draft_model"] = { - **draft_override_args, - **unwrap(kwargs.get("draft_model"), {}), - } - - # Merge the override and model kwargs - merged_kwargs = {**override_args, **kwargs} - return merged_kwargs - async def find_prompt_template(self, prompt_template_name, model_directory): """Tries to find a prompt template using various methods.""" diff --git a/common/model.py b/common/model.py index f4eeeee..2908fc3 100644 --- a/common/model.py +++ b/common/model.py @@ -4,10 +4,12 @@ Manages the storage and utility of model containers. Containers exist as a common interface for backends. """ +import aiofiles import pathlib from enum import Enum from fastapi import HTTPException from loguru import logger +from ruamel.yaml import YAML from typing import Optional from backends.base_model_container import BaseModelContainer @@ -15,6 +17,7 @@ from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.tabby_config import config from common.optional_dependencies import dependencies +from common.utils import unwrap # Global variables for model container container: Optional[BaseModelContainer] = None @@ -43,6 +46,37 @@ def load_progress(module, modules): yield module, modules +# TODO: Change this to be inline with config.yml +async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): + """Sets overrides from a model folder's config yaml.""" + + override_config_path = model_dir / "tabby_config.yml" + + if not override_config_path.exists(): + return kwargs + + async with aiofiles.open( + override_config_path, "r", encoding="utf8" + ) as override_config_file: + contents = await override_config_file.read() + + # Create a temporary YAML parser + yaml = YAML(typ="safe") + override_args = unwrap(yaml.load(contents), {}) + + # Merge draft overrides beforehand + draft_override_args = unwrap(override_args.get("draft_model"), {}) + if draft_override_args: + kwargs["draft_model"] = { + **draft_override_args, + **unwrap(kwargs.get("draft_model"), {}), + } + + # Merge the override and model kwargs + merged_kwargs = {**override_args, **kwargs} + return merged_kwargs + + async def unload_model(skip_wait: bool = False, shutdown: bool = False): """Unloads a model""" global container @@ -70,8 +104,15 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Reset to prepare for a new container container = None - # Merge with config defaults + # Model_dir is already provided + # TODO: Isolate the root cause + kwargs.pop("model_dir") + + # Merge with config and inline defaults kwargs = {**config.model_defaults, **kwargs} + kwargs = await apply_inline_overrides(model_path, **kwargs) + + print(kwargs) # Create a new container new_container = await ExllamaV2Container.create(