Generator: Fix semaphore scheduling

Non-streaming tasks were not regulated by the semaphore, allowing them
to interfere with streaming generations. Add helper functions that
accept both sync and async callbacks and serialize them behind the
semaphore.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2023-12-21 21:39:45 -05:00
Parent: bee758dae9
Commit: 1a8afcb6ad
3 changed files with 24 additions and 7 deletions

generators.py

@@ -1,4 +1,6 @@
+import inspect
 from asyncio import Semaphore
+from functools import partial
 from typing import AsyncGenerator
 
 generate_semaphore = Semaphore(1)
@@ -6,5 +8,16 @@ generate_semaphore = Semaphore(1)
 # Async generation that blocks on a semaphore
 async def generate_with_semaphore(generator: AsyncGenerator):
     async with generate_semaphore:
-        async for result in generator():
-            yield result
+        if inspect.isasyncgenfunction(generator):
+            async for result in generator():
+                yield result
+        else:
+            for result in generator():
+                yield result
+
+# Block a function with the semaphore
+async def call_with_semaphore(callback: partial):
+    if inspect.iscoroutinefunction(callback):
+        return await callback()
+    async with generate_semaphore:
+        return callback()
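For context, a minimal sketch of how the non-streaming path is expected
to queue (hypothetical driver code, not part of this commit;
blocking_generate stands in for model_container.generate):

import asyncio
from functools import partial

from generators import call_with_semaphore

def blocking_generate(prompt: str):
    # Hypothetical stand-in for model_container.generate
    return f"completion for {prompt!r}"

async def main():
    # Each sync callback acquires generate_semaphore before running,
    # so these calls execute one at a time instead of interleaving
    # with streaming generations.
    results = await asyncio.gather(
        call_with_semaphore(partial(blocking_generate, "a")),
        call_with_semaphore(partial(blocking_generate, "b")),
    )
    print(results)

asyncio.run(main())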

main.py

@@ -5,13 +5,14 @@ from asyncio import CancelledError
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
+from functools import partial
 from progress.bar import IncrementalBar
 from typing import Optional
 from uuid import uuid4
 
 import gen_logging
 from auth import check_admin_key, check_api_key, load_auth_keys
-from generators import generate_with_semaphore
+from generators import call_with_semaphore, generate_with_semaphore
 from model import ModelContainer
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
@@ -108,7 +109,6 @@ async def list_draft_models():
     draft_model_path = pathlib.Path(draft_model_dir)
     models = get_model_list(draft_model_path.resolve())
-    print(models)
 
     return models
@@ -302,7 +302,9 @@ async def generate_completion(request: Request, data: CompletionRequest):
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, data.prompt, **data.to_gen_params())
+        )
         response = create_completion_response(response_text,
                                               prompt_tokens,
                                               completion_tokens,
@@ -367,7 +369,9 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, prompt, **data.to_gen_params())
+        )
         response = create_chat_completion_response(response_text,
                                                    prompt_tokens,
                                                    completion_tokens,
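The streaming path keeps using the existing helper; a minimal sketch of
wrapping a generator function (hypothetical stream_tokens, not part of
this commit):

from generators import generate_with_semaphore

async def stream_tokens():
    # Hypothetical async generator standing in for the model's
    # streaming generation
    for token in ("hello ", "world"):
        yield token

# generate_with_semaphore takes the generator function itself and only
# begins iterating once generate_semaphore is acquired, e.g.
# StreamingResponse(generate_with_semaphore(stream_tokens), ...)
wrapped = generate_with_semaphore(stream_tokens)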

model.py

@@ -129,7 +129,7 @@ class ModelContainer:
         )
 
         # If that fails, attempt fetching from model name
-        if self.prompt_template == None:
+        if self.prompt_template is None:
             template_match = find_template_from_model(model_directory)
             if template_match:
                 self.prompt_template = get_template_from_file(template_match)