diff --git a/generators.py b/generators.py
index 287835c..285f26a 100644
--- a/generators.py
+++ b/generators.py
@@ -1,4 +1,6 @@
+import inspect
 from asyncio import Semaphore
+from functools import partial
 from typing import AsyncGenerator
 
 generate_semaphore = Semaphore(1)
@@ -6,5 +8,16 @@ generate_semaphore = Semaphore(1)
 # Async generation that blocks on a semaphore
 async def generate_with_semaphore(generator: AsyncGenerator):
     async with generate_semaphore:
-        async for result in generator():
-            yield result
+        if inspect.isasyncgenfunction(generator):
+            async for result in generator():
+                yield result
+        else:
+            for result in generator():
+                yield result
+
+# Block a function with semaphore
+async def call_with_semaphore(callback: partial):
+    if inspect.iscoroutinefunction(callback):
+        return await callback()
+    async with generate_semaphore:
+        return callback()
diff --git a/main.py b/main.py
index 955a5ff..b0b6a06 100644
--- a/main.py
+++ b/main.py
@@ -5,13 +5,14 @@
 from asyncio import CancelledError
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
+from functools import partial
 from progress.bar import IncrementalBar
 from typing import Optional
 from uuid import uuid4
 import gen_logging
 from auth import check_admin_key, check_api_key, load_auth_keys
-from generators import generate_with_semaphore
+from generators import call_with_semaphore, generate_with_semaphore
 from model import ModelContainer
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
@@ -108,7 +109,6 @@ async def list_draft_models():
     draft_model_path = pathlib.Path(draft_model_dir)
 
     models = get_model_list(draft_model_path.resolve())
-    print(models)
 
     return models
 
@@ -302,7 +302,9 @@ async def generate_completion(request: Request, data: CompletionRequest):
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, data.prompt, **data.to_gen_params())
+        )
         response = create_completion_response(response_text,
                                               prompt_tokens,
                                               completion_tokens,
@@ -367,7 +369,9 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             media_type = "text/event-stream"
         )
     else:
-        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
+        response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+            partial(model_container.generate, prompt, **data.to_gen_params())
+        )
         response = create_chat_completion_response(response_text,
                                                    prompt_tokens,
                                                    completion_tokens,
diff --git a/model.py b/model.py
index a5d4e2b..fe091c5 100644
--- a/model.py
+++ b/model.py
@@ -129,7 +129,7 @@ class ModelContainer:
         )
 
         # If that fails, attempt fetching from model name
-        if self.prompt_template == None:
+        if self.prompt_template is None:
             template_match = find_template_from_model(model_directory)
             if template_match:
                 self.prompt_template = get_template_from_file(template_match)