Model: Add params to current model endpoint

The current model endpoint now reports the model's RoPE parameters and
max sequence length, plus the draft model's card when one is loaded.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-10 00:40:56 -05:00
parent 0f4290f05c
commit fd9f3eac87
3 changed files with 31 additions and 6 deletions

25
main.py
View File

@@ -12,7 +12,7 @@ from generators import generate_with_semaphore
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
@@ -74,7 +74,25 @@ async def list_models():
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model():
    """Return a card describing the currently loaded model.

    The card's id is the ``name`` of the path returned by
    ``model_container.get_model_path()``. Its parameters carry the RoPE
    scale, RoPE alpha and max sequence length read from the container's
    config. When the container has a draft config, a second card is built
    the same way from that config and attached as ``parameters.draft``.
    """

    model_name = model_container.get_model_path().name
    model_card = ModelCard(
        id = model_name,
        parameters = ModelCardParameters(
            rope_scale = model_container.config.scale_pos_emb,
            rope_alpha = model_container.config.scale_alpha_value,
            max_seq_len = model_container.config.max_seq_len,
        )
    )

    # The draft card is only built when a draft config exists.
    if model_container.draft_config:
        draft_card = ModelCard(
            # Passing True presumably selects the draft model's path;
            # confirm against get_model_path's signature.
            id = model_container.get_model_path(True).name,
            parameters = ModelCardParameters(
                rope_scale = model_container.draft_config.scale_pos_emb,
                rope_alpha = model_container.draft_config.scale_alpha_value,
                max_seq_len = model_container.draft_config.max_seq_len,
            )
        )
        model_card.parameters.draft = draft_card

    return model_card
@@ -132,7 +150,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
status="finished"
)
yield get_sse_packet(response.json(ensure_ascii=False))
yield get_sse_packet(response.json(ensure_ascii = False))
# Switch to model progress if the draft model is loaded
if model_container.draft_config:
@@ -345,7 +363,6 @@ if __name__ == "__main__":
config = {}
# If an initial model name is specified, create a container and load the model
model_config = unwrap(config.get("model"), {})
if "model_name" in model_config:
# TODO: Move this to model_container