Generator: Fix semaphore scheduling
Non-streaming generation tasks were not regulated by the semaphore, so they could run concurrently with, and interfere with, in-flight streaming generations. Add helper functions that accept both sync and async callables and block sequentially on the semaphore.

Signed-off-by: kingbri <bdashore3@proton.me>
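For context, a minimal sketch (not part of the commit) of the scheduling a Semaphore(1) enforces; fake_stream and fake_generate are hypothetical stand-ins for the real model calls:

import asyncio

generate_semaphore = asyncio.Semaphore(1)

async def fake_stream():
    # Streaming generation: holds the semaphore for the whole stream
    async with generate_semaphore:
        for i in range(3):
            await asyncio.sleep(0.1)  # simulate per-token latency
            print(f"stream token {i}")

async def fake_generate():
    # Non-streaming generation: queues behind the stream instead of
    # running concurrently and interfering with it
    async with generate_semaphore:
        print("generate ran after the stream finished")

async def main():
    await asyncio.gather(fake_stream(), fake_generate())

asyncio.run(main())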
generators.py
@@ -1,4 +1,6 @@
+import inspect
 from asyncio import Semaphore
+from functools import partial
 from typing import AsyncGenerator
 
 generate_semaphore = Semaphore(1)
@@ -6,5 +8,16 @@ generate_semaphore = Semaphore(1)
 # Async generation that blocks on a semaphore
 async def generate_with_semaphore(generator: AsyncGenerator):
     async with generate_semaphore:
-        async for result in generator():
-            yield result
+        if inspect.isasyncgenfunction(generator):
+            async for result in generator():
+                yield result
+        else:
+            for result in generator():
+                yield result
+
+# Block a function with the semaphore
+async def call_with_semaphore(callback: partial):
+    async with generate_semaphore:
+        if inspect.iscoroutinefunction(callback):
+            return await callback()
+        return callback()
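To make the contract concrete, a hypothetical usage sketch (stream_tokens and blocking_generate are illustrative stand-ins, not repo code): both helpers take callables rather than results, so no work starts until the semaphore is actually held.

import asyncio
from functools import partial

from generators import call_with_semaphore, generate_with_semaphore

async def stream_tokens():
    # Async generator function: generate_with_semaphore iterates it with async for
    for tok in ("a", "b", "c"):
        yield tok

def blocking_generate(prompt):
    # Plain sync function: call_with_semaphore invokes it under the semaphore
    return f"echo: {prompt}"

async def demo():
    async for tok in generate_with_semaphore(stream_tokens):
        print(tok)
    result = await call_with_semaphore(partial(blocking_generate, "hi"))
    print(result)

asyncio.run(demo())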
main.py
@@ -5,13 +5,14 @@ from asyncio import CancelledError
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
+from functools import partial
 from progress.bar import IncrementalBar
 from typing import Optional
 from uuid import uuid4
 
 import gen_logging
 from auth import check_admin_key, check_api_key, load_auth_keys
-from generators import generate_with_semaphore
+from generators import call_with_semaphore, generate_with_semaphore
 from model import ModelContainer
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
@@ -108,7 +109,6 @@ async def list_draft_models():
     draft_model_path = pathlib.Path(draft_model_dir)
 
     models = get_model_list(draft_model_path.resolve())
-    print(models)
 
     return models
 
@@ -302,7 +302,9 @@ async def generate_completion(request: Request, data: CompletionRequest):
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, data.prompt, **data.to_gen_params())
+        )
         response = create_completion_response(response_text,
                                               prompt_tokens,
                                               completion_tokens,
@@ -367,7 +369,9 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, prompt, **data.to_gen_params())
+        )
         response = create_chat_completion_response(response_text,
                                                    prompt_tokens,
                                                    completion_tokens,
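The partial wrapper matters in both endpoints: calling model_container.generate directly would run the generation before the semaphore is acquired. A small illustrative sketch of the deferral (generate below is a hypothetical stand-in for the model call):

from functools import partial

def generate(prompt, max_tokens=16):
    # Stand-in for a blocking model call
    return f"completion of {prompt!r}", 3, max_tokens

# Arguments are bound now, but nothing executes yet
work = partial(generate, "Hello", max_tokens=8)

# call_with_semaphore later invokes work() while holding the semaphore
print(work())  # ("completion of 'Hello'", 3, 8)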
model.py
@@ -129,7 +129,7 @@ class ModelContainer:
         )
 
         # If that fails, attempt fetching from model name
-        if self.prompt_template == None:
+        if self.prompt_template is None:
             template_match = find_template_from_model(model_directory)
             if template_match:
                 self.prompt_template = get_template_from_file(template_match)
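The == None to is None change is the standard identity-comparison fix: == dispatches to a class's __eq__, which can misreport equality with None, while is checks object identity directly. A quick general-Python illustration (AlwaysEqual is a contrived example class):

class AlwaysEqual:
    # Pathological __eq__ that claims equality with everything
    def __eq__(self, other):
        return True

obj = AlwaysEqual()
print(obj == None)  # True  -- __eq__ lies
print(obj is None)  # False -- identity check is reliable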