Model: Fix inline loading and draft key (#225)

* Model: Fix inline loading and draft key

There was an oversight in how the new config.yml was structured: the
"draft" key became "draft_model" without updating both the API request
and inline loading keys to match.

For the API requests, still support "draft" as legacy, but the "draft_model"
key is preferred.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Add draft model dir to inline load

This was not pushed before and caused errors because the kwargs were None.

Signed-off-by: kingbri <bdashore3@proton.me>

* Model: Fix draft args application

Draft model args weren't being applied because they were reset by the
old override behavior.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Change embedding model load params

Use embedding_model_name to be in line with the config.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for draft model load

Alias name to draft_model_name.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for template switch

Add prompt_template_name to be more descriptive.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for model load

Alias name to model_name for config parity.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Add alias documentation

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
Brian Dashore
2024-10-24 23:35:05 -04:00
committed by GitHub
parent f20857cb34
commit 6e48bb420a
7 changed files with 68 additions and 46 deletions

View File

@@ -149,8 +149,11 @@ async def load_inline_model(model_name: str, request: Request):
return
# Load the model
await model.load_model(model_path)
# Load the model and also add draft dir
await model.load_model(
model_path,
draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
)
async def stream_generate_completion(

View File

@@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""
# Verify request parameters
if not data.name:
if not data.model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
model_path = model_path / data.model_name
draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
if data.draft_model:
if not data.draft_model.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
@@ -301,7 +301,7 @@ async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
# Verify request parameters
if not data.name:
if not data.embedding_model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@@ -310,7 +310,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
embedding_model_path = embedding_model_dir / data.embedding_model_name
if not embedding_model_path.exists():
error_message = handle_request_error(
@@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""
if not data.name:
if not data.prompt_template_name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
@@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)
try:
template_path = pathlib.Path("templates") / data.name
template_path = pathlib.Path("templates") / data.prompt_template_name
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
f"The template name {data.prompt_template_name} doesn't exist. "
+ "Check the spelling?",
exc_info=False,
).error.message

View File

@@ -1,6 +1,6 @@
"""Contains model card types."""
from pydantic import BaseModel, Field, ConfigDict
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union
@@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
# Required
draft_model_name: str
draft_model_name: str = Field(
alias=AliasChoices("draft_model_name", "name"),
description="Aliases: name",
)
# Config arguments
draft_rope_scale: Optional[float] = None
@@ -63,8 +66,14 @@ class DraftModelLoadRequest(BaseModel):
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])
# Required
name: str
model_name: str = Field(
alias=AliasChoices("model_name", "name"),
description="Aliases: name",
)
# Config arguments
@@ -108,12 +117,18 @@ class ModelLoadRequest(BaseModel):
num_experts_per_token: Optional[int] = None
# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
skip_queue: Optional[bool] = False
class EmbeddingModelLoadRequest(BaseModel):
name: str
embedding_model_name: str = Field(
alias=AliasChoices("embedding_model_name", "name"),
description="Aliases: name",
)
# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
from typing import List
@@ -12,4 +12,7 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""
name: str
prompt_template_name: str = Field(
alias=AliasChoices("prompt_template_name", "name"),
description="Aliases: name",
)

View File

@@ -104,7 +104,7 @@ async def stream_model_load(
# Set the draft model path if it exists
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_data["draft_model"]["draft_model_dir"] = draft_model_path
load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data