From 7fded4f18346027b415b69152c4633dcb0baa558 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 14 Mar 2024 10:27:39 -0400 Subject: [PATCH] Tree: Switch to async generators Async generation helps remove many roadblocks to managing tasks using threads. It should allow for abortables and modern-day paradigms. NOTE: Exllamav2 itself is not an asynchronous library. It's just been added into tabby's async nature to allow for a fast and concurrent API server. It's still being debated to run stream_ex in a separate thread or manually manage it using asyncio.sleep(0) Signed-off-by: kingbri --- backends/exllamav2/model.py | 33 +++++++++++++++++------- common/{generators.py => concurrency.py} | 2 +- common/model.py | 6 ++--- common/utils.py | 9 +++++++ endpoints/OAI/app.py | 26 ++++++++----------- endpoints/OAI/utils/chat_completion.py | 32 +++++++++++------------ endpoints/OAI/utils/completion.py | 33 +++++++++++------------- endpoints/OAI/utils/lora.py | 2 +- endpoints/OAI/utils/model.py | 18 +++---------- main.py | 11 +------- 10 files changed, 84 insertions(+), 88 deletions(-) rename common/{generators.py => concurrency.py} (96%) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 77908d8..5cf7b60 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import asyncio import gc from itertools import zip_longest import pathlib @@ -325,7 +326,7 @@ class ExllamaV2Container: return model_params - def load(self, progress_callback=None): + async def load(self, progress_callback=None): """ Load model @@ -338,7 +339,7 @@ class ExllamaV2Container: for _ in self.load_gen(progress_callback): pass - def load_loras(self, lora_directory: pathlib.Path, **kwargs): + async def load_loras(self, lora_directory: pathlib.Path, **kwargs): """ Load loras """ @@ -361,7 +362,7 @@ class ExllamaV2Container: logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}") lora_path = lora_directory / lora_name - # FIXME(alpin): Does self.model need to be passed here? 
+ self.active_loras.append( ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling) ) @@ -371,7 +372,7 @@ class ExllamaV2Container: # Return success and failure names return {"success": success, "failure": failure} - def load_gen(self, progress_callback=None): + async def load_gen(self, progress_callback=None): """ Load model, generator function @@ -400,12 +401,16 @@ class ExllamaV2Container: logger.info("Loading draft model: " + self.draft_config.model_dir) self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True) - yield from self.draft_model.load_autosplit_gen( + for value in self.draft_model.load_autosplit_gen( self.draft_cache, reserve_vram=autosplit_reserve, last_id_only=True, callback_gen=progress_callback, - ) + ): + # Manually suspend the task to allow for other stuff to run + await asyncio.sleep(0) + if value: + yield value # Test VRAM allocation with a full-length forward pass input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) @@ -424,6 +429,8 @@ class ExllamaV2Container: self.gpu_split, callback_gen=progress_callback, ): + # Manually suspend the task to allow for other stuff to run + await asyncio.sleep(0) if value: yield value @@ -452,6 +459,8 @@ class ExllamaV2Container: last_id_only=True, callback_gen=progress_callback, ): + # Manually suspend the task to allow for other stuff to run + await asyncio.sleep(0) if value: yield value @@ -565,9 +574,11 @@ class ExllamaV2Container: return dict(zip_longest(top_tokens, cleaned_values)) - def generate(self, prompt: str, **kwargs): + async def generate(self, prompt: str, **kwargs): """Generate a response to a prompt""" - generations = list(self.generate_gen(prompt, **kwargs)) + generations = [] + async for generation in self.generate_gen(prompt, **kwargs): + generations.append(generation) joined_generation = { "text": "", @@ -615,8 +626,7 @@ class ExllamaV2Container: return kwargs - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - def generate_gen(self, prompt: str, **kwargs): + async def generate_gen(self, prompt: str, **kwargs): """ Create generator function for prompt completion @@ -889,6 +899,9 @@ class ExllamaV2Container: chunk_tokens = 0 while True: + # Manually suspend the task to allow for other stuff to run + await asyncio.sleep(0) + # Ingest prompt if chunk_tokens == 0: ids = torch.cat((ids, save_tokens), dim=-1) diff --git a/common/generators.py b/common/concurrency.py similarity index 96% rename from common/generators.py rename to common/concurrency.py index 485ca63..6cf9fed 100644 --- a/common/generators.py +++ b/common/concurrency.py @@ -1,4 +1,4 @@ -"""Generator handling""" +"""Concurrency handling""" import asyncio import inspect diff --git a/common/model.py b/common/model.py index de8e6ad..1e1d432 100644 --- a/common/model.py +++ b/common/model.py @@ -52,7 +52,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): progress.start() try: - for module, modules in load_status: + async for module, modules in load_status: if module == 0: loading_task = progress.add_task( f"[cyan]Loading {model_type} modules", total=modules @@ -76,12 +76,12 @@ async def load_model(model_path: pathlib.Path, **kwargs): pass -def load_loras(lora_dir, **kwargs): +async def load_loras(lora_dir, **kwargs): """Wrapper to load loras.""" if len(container.active_loras) > 0: unload_loras() - return container.load_loras(lora_dir, **kwargs) + return await container.load_loras(lora_dir, **kwargs) def unload_loras(): diff --git a/common/utils.py b/common/utils.py index 
5ad12a1..dd79e9a 100644 --- a/common/utils.py +++ b/common/utils.py @@ -6,6 +6,8 @@ from loguru import logger from pydantic import BaseModel from typing import Optional +from common.concurrency import release_semaphore + def load_progress(module, modules): """Wrapper callback for load progress.""" @@ -51,6 +53,13 @@ def handle_request_error(message: str, exc_info: bool = True): return request_error +def handle_request_disconnect(message: str): + """Wrapper for handling for request disconnection.""" + + release_semaphore() + logger.error(message) + + def unwrap(wrapped, default=None): """Unwrap function for Optionals.""" if wrapped is None: diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py index 09303b0..0571874 100644 --- a/endpoints/OAI/app.py +++ b/endpoints/OAI/app.py @@ -1,7 +1,6 @@ import pathlib import uvicorn from fastapi import FastAPI, Depends, HTTPException, Request -from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from functools import partial from loguru import logger @@ -10,7 +9,7 @@ from sys import maxsize from common import config, model, gen_logging, sampling from common.auth import check_admin_key, check_api_key -from common.generators import ( +from common.concurrency import ( call_with_semaphore, generate_with_semaphore, ) @@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest): if not model_path.exists(): raise HTTPException(400, "model_path does not exist. Check model_name?") - load_callback = partial( - stream_model_load, request, data, model_path, draft_model_path - ) + load_callback = partial(stream_model_load, data, model_path, draft_model_path) # Wrap in a semaphore if the queue isn't being skipped if data.skip_queue: @@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest): "A parent lora directory does not exist. 
Check your config.yml?", ) - load_callback = partial( - run_in_threadpool, model.load_loras, lora_dir, **data.model_dump() - ) + load_callback = partial(model.load_loras, lora_dir, **data.model_dump()) # Wrap in a semaphore if the queue isn't being skipped if data.skip_queue: @@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest): ) if data.stream and not disable_request_streaming: - generator_callback = partial( - stream_generate_completion, request, data, model_path - ) + generator_callback = partial(stream_generate_completion, data, model_path) return EventSourceResponse( generate_with_semaphore(generator_callback), @@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) if data.stream and not disable_request_streaming: generator_callback = partial( - stream_generate_chat_completion, prompt, request, data, model_path + stream_generate_chat_completion, prompt, data, model_path ) return EventSourceResponse( @@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) ) else: response = await call_with_semaphore( - partial(generate_chat_completion, prompt, request, data, model_path) + partial(generate_chat_completion, prompt, data, model_path) ) return response -def start_api(host: str, port: int): +async def start_api(host: str, port: int): """Isolated function to start the API server""" # TODO: Move OAI API to a separate folder @@ -475,9 +468,12 @@ def start_api(host: str, port: int): logger.info(f"Completions: http://{host}:{port}/v1/completions") logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") - uvicorn.run( + config = uvicorn.Config( app, host=host, port=port, log_config=UVICORN_LOG_CONFIG, ) + server = uvicorn.Server(config) + + await server.serve() diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 022ce82..c88c258 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,18 +1,21 @@ """Chat completion utilities for OAI server.""" +from asyncio import CancelledError import pathlib from typing import Optional from uuid import uuid4 -from fastapi import HTTPException, Request -from fastapi.concurrency import run_in_threadpool +from fastapi import HTTPException from jinja2 import TemplateError -from loguru import logger from common import model -from common.generators import release_semaphore from common.templating import get_prompt_from_template -from common.utils import get_generator_error, handle_request_error, unwrap +from common.utils import ( + get_generator_error, + handle_request_disconnect, + handle_request_error, + unwrap, +) from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, @@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest): async def stream_generate_chat_completion( - prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path + prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path ): """Generator for the generation process.""" try: const_id = f"chatcmpl-{uuid4().hex}" new_generation = model.container.generate_gen(prompt, **data.to_gen_params()) - for generation in new_generation: - # Get out if the request gets disconnected - if await request.is_disconnected(): - release_semaphore() - logger.error("Chat completion generation cancelled by user.") - return - + async for generation in new_generation: response = 
_create_stream_chunk(const_id, generation, model_path.name) yield response.model_dump_json() @@ -172,6 +169,10 @@ async def stream_generate_chat_completion( finish_response = _create_stream_chunk(const_id, finish_reason="stop") yield finish_response.model_dump_json() + except CancelledError: + # Get out if the request gets disconnected + + handle_request_disconnect("Chat completion generation cancelled by user.") except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." @@ -179,11 +180,10 @@ async def stream_generate_chat_completion( async def generate_chat_completion( - prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path + prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path ): try: - generation = await run_in_threadpool( - model.container.generate, + generation = await model.container.generate( prompt, **data.to_gen_params(), ) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index ec687fd..7fc993f 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -1,14 +1,17 @@ """Completion utilities for OAI server.""" +from asyncio import CancelledError import pathlib -from fastapi import HTTPException, Request -from fastapi.concurrency import run_in_threadpool -from loguru import logger +from fastapi import HTTPException from typing import Optional from common import model -from common.generators import release_semaphore -from common.utils import get_generator_error, handle_request_error, unwrap +from common.utils import ( + get_generator_error, + handle_request_disconnect, + handle_request_error, + unwrap, +) from endpoints.OAI.types.completion import ( CompletionRequest, CompletionResponse, @@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]): return response -async def stream_generate_completion( - request: Request, data: CompletionRequest, model_path: pathlib.Path -): +async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path): """Streaming generation for completions.""" try: new_generation = model.container.generate_gen( data.prompt, **data.to_gen_params() ) - for generation in new_generation: - # Get out if the request gets disconnected - if await request.is_disconnected(): - release_semaphore() - logger.error("Completion generation cancelled by user.") - return - + async for generation in new_generation: response = _create_response(generation, model_path.name) yield response.model_dump_json() # Yield a finish response on successful generation yield "[DONE]" + except CancelledError: + # Get out if the request gets disconnected + + handle_request_disconnect("Completion generation cancelled by user.") except Exception: yield get_generator_error( "Completion aborted. Please check the server console." 
@@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path) """Non-streaming generate for completions""" try: - generation = await run_in_threadpool( - model.container.generate, data.prompt, **data.to_gen_params() - ) + generation = await model.container.generate(data.prompt, **data.to_gen_params()) response = _create_response(generation, model_path.name) return response diff --git a/endpoints/OAI/utils/lora.py b/endpoints/OAI/utils/lora.py index 809fdea..d00910f 100644 --- a/endpoints/OAI/utils/lora.py +++ b/endpoints/OAI/utils/lora.py @@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path): for path in lora_path.iterdir(): if path.is_dir(): lora_card = LoraCard(id=path.name) - lora_list.data.append(lora_card) # pylint: disable=no-member + lora_list.data.append(lora_card) return lora_list diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py index b19068b..61d210f 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/OAI/utils/model.py @@ -1,12 +1,9 @@ import pathlib from asyncio import CancelledError -from fastapi import Request -from loguru import logger from typing import Optional from common import model -from common.generators import release_semaphore -from common.utils import get_generator_error +from common.utils import get_generator_error, handle_request_disconnect from endpoints.OAI.types.model import ( ModelCard, @@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N async def stream_model_load( - request: Request, data: ModelLoadRequest, model_path: pathlib.Path, draft_model_path: str, @@ -50,14 +46,6 @@ async def stream_model_load( load_status = model.load_model_gen(model_path, **load_data) try: async for module, modules, model_type in load_status: - if await request.is_disconnected(): - release_semaphore() - logger.error( - "Model load cancelled by user. " - "Please make sure to run unload to free up resources." - ) - return - if module != 0: response = ModelLoadResponse( model_type=model_type, @@ -78,7 +66,9 @@ async def stream_model_load( yield response.model_dump_json() except CancelledError: - logger.error( + # Get out if the request gets disconnected + + handle_request_disconnect( "Model load cancelled by user. " "Please make sure to run unload to free up resources." ) diff --git a/main.py b/main.py index fe1e6a2..a3d7c30 100644 --- a/main.py +++ b/main.py @@ -5,9 +5,6 @@ import os import pathlib import signal import sys -import threading -import time -from functools import partial from loguru import logger from typing import Optional @@ -121,13 +118,7 @@ async def entrypoint(args: Optional[dict] = None): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) model.container.load_loras(lora_dir.resolve(), **lora_config) - # TODO: Replace this with abortables, async via producer consumer, or something else - api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True) - - api_thread.start() - # Keep the program alive - while api_thread.is_alive(): - time.sleep(0.5) + await start_api(host, port) if __name__ == "__main__":
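
Note for reviewers: the core pattern in this patch is wrapping ExLlamaV2's synchronous generators in async generators that call `await asyncio.sleep(0)` on every iteration, handing control back to the event loop between chunks (see the additions to `load_gen` and `generate_gen` above). A minimal, self-contained sketch of that idea; the `sync_chunk_gen` and `cooperative_gen` names are illustrative only, not part of tabby or exllamav2:

import asyncio


def sync_chunk_gen(n: int):
    """Stand-in for a synchronous ExLlamaV2-style generator."""
    for i in range(n):
        yield f"chunk {i}"


async def cooperative_gen(n: int):
    """Re-yield values from the sync generator, suspending on every
    iteration so other tasks on the event loop get a turn."""
    for value in sync_chunk_gen(n):
        # Manually suspend the task, mirroring the awaits added in load_gen/generate_gen
        await asyncio.sleep(0)
        yield value


async def main():
    async for value in cooperative_gen(3):
        print(value)


if __name__ == "__main__":
    asyncio.run(main())

The trade-off flagged in the commit message is that `asyncio.sleep(0)` only yields between iterations; a long forward pass inside a single step still blocks the loop, which is why moving `stream_ex` into a separate thread is still being debated.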
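
The streaming endpoints now rely on the request task being cancelled when the client disconnects: the `CancelledError` surfaces inside the `async for` loop, and `handle_request_disconnect` logs it and releases the generation semaphore, replacing the old `request.is_disconnected()` polling. A rough sketch of that shape, collapsed into one function for brevity; `generation_semaphore` and `fake_generate_gen` are hypothetical stand-ins for what `common/concurrency.py` and `model.container.generate_gen` provide:

import asyncio

# Hypothetical stand-in for the semaphore that common/concurrency.py
# is assumed to manage via release_semaphore().
generation_semaphore = asyncio.Semaphore(1)


async def fake_generate_gen():
    """Stand-in for model.container.generate_gen(...)."""
    for token in ("Hello", " ", "world"):
        await asyncio.sleep(0)
        yield token


async def stream_generate():
    """SSE-style generator: release the semaphore if the client disconnects."""
    await generation_semaphore.acquire()
    try:
        async for token in fake_generate_gen():
            yield token
        yield "[DONE]"
        generation_semaphore.release()
    except asyncio.CancelledError:
        # The streaming task is cancelled when the request disconnects
        generation_semaphore.release()
        print("Generation cancelled by user.")

Since Python 3.8, `asyncio.CancelledError` derives from `BaseException`, so the broad `except Exception` fallbacks in these handlers will not swallow a disconnect; the dedicated `except CancelledError` branch is what catches it.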
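
Finally, the server no longer runs in a daemon thread: `uvicorn.run()` is replaced with the programmatic `uvicorn.Config`/`uvicorn.Server` API so `start_api` can be awaited from `entrypoint` on the same event loop as model loading and generation. A minimal sketch of that pattern; the toy `app` and `/health` route are placeholders, not tabby's OAI app:

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
async def health():
    return {"status": "ok"}


async def start_api(host: str, port: int):
    """Serve the app on the current event loop instead of a separate thread."""
    config = uvicorn.Config(app, host=host, port=port)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    asyncio.run(start_api("127.0.0.1", 8000))

Because `server.serve()` is awaited directly, the old keep-alive loop (`while api_thread.is_alive(): time.sleep(0.5)`) in main.py is no longer needed.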