mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-24 16:29:18 +00:00
Model + API: Migrate to use BaseSamplerParams
kwargs is pretty ugly when figuring out which arguments to use. The base requests falls back to defaults anyways, so pass in the params object as is. However, since Python's typing isn't like TypeScript where types can be transformed, the type hinting has a possiblity of None showing up despite there always being a value for some params. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -8,12 +8,12 @@ import asyncio
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import HTTPException, Request
|
||||
from typing import List, Union
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from common import model
|
||||
from common.auth import get_key_permission
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
@@ -86,16 +86,21 @@ def _create_response(
|
||||
async def _stream_collector(
|
||||
task_idx: int,
|
||||
gen_queue: asyncio.Queue,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: CompletionRequest,
|
||||
abort_event: asyncio.Event,
|
||||
**kwargs,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""Collects a stream and places results in a common queue"""
|
||||
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
abort_event,
|
||||
mm_embeddings,
|
||||
)
|
||||
async for generation in new_generation:
|
||||
generation["index"] = task_idx
|
||||
@@ -195,10 +200,10 @@ async def stream_generate_completion(
|
||||
_stream_collector(
|
||||
n,
|
||||
gen_queue,
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -256,9 +261,9 @@ async def generate_completion(
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user