Model + API: Migrate to use BaseSamplerParams

kwargs is pretty ugly when figuring out which arguments to use. The
base request falls back to defaults anyway, so pass in the params
object as is.

However, since Python's typing isn't like TypeScript's, where types
can be transformed, the type hints have a possibility of None showing
up despite there always being a value for some params.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri
2025-04-16 00:50:05 -04:00
parent dcb36e9ab2
commit 3084ef9fa1
5 changed files with 113 additions and 121 deletions

View File

@@ -8,12 +8,12 @@ import asyncio
import pathlib
from asyncio import CancelledError
from fastapi import HTTPException, Request
from typing import List, Union
from loguru import logger
from typing import List, Optional, Union
from common import model
from common.auth import get_key_permission
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import (
get_generator_error,
handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
request_id: str,
prompt: str,
params: CompletionRequest,
abort_event: asyncio.Event,
**kwargs,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(
prompt, request_id, abort_event, **kwargs
request_id,
prompt,
params,
abort_event,
mm_embeddings,
)
async for generation in new_generation:
generation["index"] = task_idx
@@ -195,10 +200,10 @@ async def stream_generate_completion(
_stream_collector(
n,
gen_queue,
data.prompt,
request.state.id,
data.prompt,
task_gen_params,
abort_event,
**task_gen_params.model_dump(exclude={"prompt"}),
)
)
@@ -256,9 +261,9 @@ async def generate_completion(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
data.prompt,
request.state.id,
**task_gen_params.model_dump(exclude={"prompt"}),
data.prompt,
task_gen_params,
)
)
)