Model: Fix inline loading and draft key (#225)

* Model: Fix inline loading and draft key

The new config.yml renamed the "draft" key to "draft_model", but the
API request and inline loading keys were not updated to match.

API requests still accept "draft" as a legacy key, but "draft_model"
is preferred.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Add draft model dir to inline load

This was not pushed before, which caused errors because the kwargs were None.

Signed-off-by: kingbri <bdashore3@proton.me>

* Model: Fix draft args application

Draft model args weren't applying since there was a reset due to how
the old override behavior worked.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Change embedding model load params

Use embedding_model_name to be in line with the config.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for draft model load

Alias name to draft_model_name.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for template switch

Add prompt_template_name to be more descriptive.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for model load

Alias name to model_name for config parity.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Add alias documentation

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Signed-off-by: kingbri <bdashore3@proton.me>
Brian Dashore
2024-10-24 23:35:05 -04:00
committed by GitHub
parent f20857cb34
commit 6e48bb420a
7 changed files with 68 additions and 46 deletions
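
The legacy-key handling described in the commit message can be sketched in plain Python as follows. This is only an illustration under stated assumptions: `KEY_ALIASES` and `normalize_load_request` are hypothetical names that do not appear in the commit, and the project itself may well do the aliasing in its request models rather than on raw dicts.

```python
from typing import Any, Dict, Tuple

# Hypothetical alias table: preferred key -> accepted legacy spellings.
# Only the "draft" -> "draft_model" rename is taken from the commit message.
KEY_ALIASES: Dict[str, Tuple[str, ...]] = {
    "draft_model": ("draft",),
}


def normalize_load_request(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the payload that uses only the preferred key names."""
    normalized = dict(payload)  # shallow copy, the caller's dict is untouched
    for preferred, legacy_keys in KEY_ALIASES.items():
        for legacy in legacy_keys:
            if legacy not in normalized:
                continue
            legacy_value = normalized.pop(legacy)
            # If both spellings were sent, the preferred key wins.
            normalized.setdefault(preferred, legacy_value)
    return normalized


if __name__ == "__main__":
    # A request that still uses the legacy "draft" key.
    print(normalize_load_request({"draft": {"draft_model_name": "example"}}))
```

Normalizing once at the request boundary means the loading code only ever has to look up "draft_model".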


@@ -129,8 +129,27 @@ class ExllamaV2Container:
 # Check if the model arch is compatible with various exl2 features
 self.config.arch_compat_overrides()
+# Create the hf_config
+self.hf_config = await HuggingFaceConfig.from_file(model_directory)
+# Load generation config overrides
+generation_config_path = model_directory / "generation_config.json"
+if generation_config_path.exists():
+    try:
+        self.generation_config = await GenerationConfig.from_file(
+            generation_config_path.parent
+        )
+    except Exception:
+        logger.error(traceback.format_exc())
+        logger.warning(
+            "Skipping generation config load because of an unexpected error."
+        )
+# Apply a model's config overrides while respecting user settings
+kwargs = await self.set_model_overrides(**kwargs)
 # Prepare the draft model config if necessary
-draft_args = unwrap(kwargs.get("draft"), {})
+draft_args = unwrap(kwargs.get("draft_model"), {})
 draft_model_name = draft_args.get("draft_model_name")
 enable_draft = draft_args and draft_model_name
@@ -154,25 +173,6 @@ class ExllamaV2Container:
 self.draft_config.model_dir = str(draft_model_path.resolve())
 self.draft_config.prepare()
-# Create the hf_config
-self.hf_config = await HuggingFaceConfig.from_file(model_directory)
-# Load generation config overrides
-generation_config_path = model_directory / "generation_config.json"
-if generation_config_path.exists():
-    try:
-        self.generation_config = await GenerationConfig.from_file(
-            generation_config_path.parent
-        )
-    except Exception:
-        logger.error(traceback.format_exc())
-        logger.warning(
-            "Skipping generation config load because of an unexpected error."
-        )
-# Apply a model's config overrides while respecting user settings
-kwargs = await self.set_model_overrides(**kwargs)
 # MARK: User configuration
 # Get cache mode
@@ -338,9 +338,6 @@ class ExllamaV2Container:
 # Set user-configured draft model values
 if enable_draft:
-    # Fetch from the updated kwargs
-    draft_args = unwrap(kwargs.get("draft"), {})
     self.draft_config.max_seq_len = self.config.max_seq_len
     self.draft_config.scale_pos_emb = unwrap(
@@ -384,9 +381,12 @@ class ExllamaV2Container:
 override_args = unwrap(yaml.load(contents), {})
 # Merge draft overrides beforehand
-draft_override_args = unwrap(override_args.get("draft"), {})
-if self.draft_config and draft_override_args:
-    kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
+draft_override_args = unwrap(override_args.get("draft_model"), {})
+if draft_override_args:
+    kwargs["draft_model"] = {
+        **draft_override_args,
+        **unwrap(kwargs.get("draft_model"), {}),
+    }
 # Merge the override and model kwargs
 merged_kwargs = {**override_args, **kwargs}
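
The merge order in the last hunk can be tried out in isolation with the sketch below. It is a minimal illustration, not the project's code: `unwrap` here is a stand-in that assumes the real helper falls back to the default when the value is None, `merge_overrides` is a hypothetical wrapper name, and `example_option` is a made-up setting.

```python
from typing import Any, Dict, Optional


def unwrap(value: Optional[Any], default: Any) -> Any:
    """Stand-in for the unwrap helper in the diff (assumed behavior)."""
    return default if value is None else value


def merge_overrides(
    override_args: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge model-folder overrides with user kwargs; user values win."""
    kwargs = dict(kwargs)  # avoid mutating the caller's dict

    # Merge the nested draft settings first so that the top-level merge
    # below does not replace the whole nested dict with the user's copy.
    draft_override_args = unwrap(override_args.get("draft_model"), {})
    if draft_override_args:
        kwargs["draft_model"] = {
            **draft_override_args,
            # unwrap guards against a missing key: unpacking None with **
            # raises TypeError, which the removed line was exposed to.
            **unwrap(kwargs.get("draft_model"), {}),
        }

    # Later entries win, so user-supplied kwargs override the file values.
    return {**override_args, **kwargs}


if __name__ == "__main__":
    overrides = {
        "max_seq_len": 4096,
        "draft_model": {"draft_model_name": "tiny", "example_option": 1},
    }
    user = {"max_seq_len": 8192, "draft_model": {"example_option": 2}}
    print(merge_overrides(overrides, user))
```

Because later entries win in a dict display, placing the user's values last at both levels keeps user settings authoritative while still filling in anything the user left out.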